Coverage for src/thunderlab/multivariateexplorer.py: 88%
981 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 16:02 +0000
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-09 16:02 +0000
1"""Simple GUI for viewing and exploring multivariate data.
- `class MultivariateExplorer`: simple matplotlib-based GUI for viewing and exploring multivariate data.
- `categorize()`: convert categorical string data into integer categories.
5- `select_features()`: assemble list of column indices.
6- `select_coloring()`: select column from data table for colorizing the data.
7- `list_available_features()`: print available features on console.
8"""
10import sys
11import numpy as np
12from scipy.stats import pearsonr
13from sklearn import decomposition
14from sklearn import preprocessing
15import matplotlib.pyplot as plt
16import matplotlib.patches as patches
17import matplotlib.widgets as widgets
18import argparse
19from .version import __version__, __year__
20from .tabledata import TableData
23class MultivariateExplorer(object):
24 """Simple matplotlib-based GUI for viewing and exploring multivariate data.
    Shown are scatter plots of all pairs of variables or PCA axes.
27 Points in the scatter plots are colored according to the values of one of the variables.
28 Data points can be selected and optionally corresponding waveforms are shown.
30 First you initialize the explorer with the data. Then you optionally
31 specify how to colorize the data and provide waveform data
32 associated with the data. Finally you show the figure:
33 ```
34 expl = MultivariateExplorer(data)
35 expl.set_colors(2)
36 expl.set_wave_data(waveforms, 'Time [s]', 'Sine')
37 expl.show()
38 ```
    The `compute_pca()` function computes a principal component analysis (PCA)
41 on the input data, and `save_pca()` writes the principal components to a file.
43 Customize the appearance and information provided by subclassing
44 MultivariateExplorer and reimplementing the functions
45 - fix_scatter_plot()
46 - fix_waveform_plot()
47 - list_selection()
48 - analyze_selection()
49 See the documentation of these functions for details.
50 """
52 mouse_actions = [
53 ('left click', 'select sample'),
54 ('left and drag', 'rectangular selection of samples and/or zoom'),
55 ('shift + left click/drag', 'add samples to selection'),
56 ('ctrl + left click/drag', 'remove samples from selection')
57 ]
58 """List of tuples with mouse actions and a description of their action."""
60 key_actions = [
61 ('c, C', 'cycle color map trough data columns'),
62 ('p,P', 'toggle between features, PCs, and scaled PCs'),
63 ('<, pageup', 'decrease number of displayed featured/PCs'),
64 ('>, pagedown', 'increase number of displayed features/PCs'),
65 ('o, z', 'toggle zoom mode on or off'),
66 ('backspace', 'zoom back'),
67 ('n, N', 'decrease, increase number of bins of histograms'),
68 ('H', 'toggle between scatter plot and 2D histogram'),
69 ('left, right, up, down', 'show and move magnified scatter plot'),
70 ('escape', 'close magnified scatter plot'),
71 ('ctrl + a', 'select all'),
72 ('+, -', 'increase, decrease pick radius'),
73 ('0', 'reset pick radius'),
74 ('l', 'list selection on console'),
75 ('w', 'toggle maximized waveform plot'),
76 ('h', 'toggle help window'),
77 ]
78 """List of tuples with key shortcuts and a description of their action."""
80 def __init__(self, data, labels=None, title=None):
81 """Initialize explorer with scatter-plot data.
83 Parameters
84 ----------
85 data: TableData, 2D array, or list of 1D arrays
86 The data to be explored. Each column is a variable.
87 For the 2D array the columns are the second dimension,
88 for a list of 1D arrays, the list goes over columns,
89 i.e. each 1D array is one column.
90 labels: list of str
91 If data is not a TableData, then this provides labels
92 for the data columns.
93 title: str
94 Title for the window.
95 """
96 # data. categories and labels:
97 self.raw_data = None # original data table as 2D numpy array (samples x features)
98 self.raw_labels = None # for each feature a label, optional with unit
99 self.categories = [] # for each feature None or list of categories
100 if isinstance(data, TableData):
101 for c, col in enumerate(data):
102 if not isinstance(col[0], (int, float,
103 np.integer, np.floating)):
104 # categorial data:
105 #print(data[:,c])
106 cats, data[:,c] = categorize(col)
107 self.categories.append(cats)
108 else:
109 self.categories.append(None)
110 self.raw_data = data.array()
111 if labels is None:
112 self.raw_labels = []
113 for c in range(len(data)):
114 if len(data.unit(c)) > 0 and not data.unit(c) in ['-', '1']:
115 self.raw_labels.append(f'{data.label(c)} [{data.unit(c)}]')
116 else:
117 self.raw_labels.append(data.label(c))
118 else:
119 self.raw_labels = labels
120 else:
121 if isinstance(data, np.ndarray):
122 self.raw_data = data
123 self.categories = [None] * data.shape[1]
124 else:
125 for c, col in enumerate(data):
126 if not isinstance(col[0], (int, float,
127 np.integer, np.floating)):
128 # categorial data:
129 cats, data[c] = categorize(col)
130 self.categories.append(cats)
131 else:
132 self.categories.append(None)
133 self.raw_data = np.asarray(data).T
134 self.raw_labels = labels
135 # remove columns containing only invalid numbers:
136 cols = np.all(~np.isfinite(self.raw_data), 0)
137 if np.sum(cols) > 0:
138 print('removed columns containing no numbers:',
139 [l for l, c in zip(self.raw_labels, cols) if c])
140 self.raw_data = self.raw_data[:, ~cols]
141 self.raw_labels = [l for l, c in zip(self.raw_labels, cols) if not c]
142 # remove rows containing invalid numbers:
143 self.valid_samples = ~np.any(~np.isfinite(self.raw_data), 1)
144 self.raw_data = self.raw_data[self.valid_samples, :]
145 if np.sum(~self.valid_samples) > 0:
146 print(f'removed {np.sum(~self.valid_samples)} rows containing invalid numbers:')
147 for k in range(len(self.valid_samples)):
148 if not self.valid_samples[k]:
149 print(k)
150 self.valid_rows = [k for k in range(len(self.valid_samples))
151 if self.valid_samples[k]]
152 # title for the window:
153 self.title = title if title is not None else 'MultivariateExplorer'
154 # data, pca-data, scaled-pca data (no pca data yet):
155 self.all_data = [self.raw_data, None, None]
156 self.all_labels = [self.raw_labels, None, None]
157 self.all_maxcols = [self.raw_data.shape[1], None, None]
158 self.all_titles = ['data', 'PCA', 'scaled PCA'] # added to window title
159 # pca:
160 self.pca_tables = [None, None] # raw and scaled pca coefficients
161 self._pca_header(self.raw_data, self.raw_labels) # prepare header of the pca tables
162 # start showing raw data:
163 self.show_mode = 0 # show data, pca or scaled pca
164 self.data = self.all_data[self.show_mode] # the data shown
165 self.labels = self.all_labels[self.show_mode] # the feature labels shown
166 self.maxcols = self.all_maxcols[self.show_mode] # maximum number of features currently shown
167 if self.maxcols > 6:
168 self.maxcols = 6
169 # waveform data:
170 self.wave_data = []
171 self.wave_nested = False
172 self.wave_has_xticks = []
173 self.wave_xlabels = []
174 self.wave_ylabels = []
175 self.wave_title = False
176 # colors:
177 self.color_map = plt.get_cmap('jet')
178 self.extra_colors = None # additional data column to be used for coloring
179 self.extra_color_label = None # label for extra_colors
180 self.extra_categories = None # category name for extra_colors if needed
181 self.color_set_index = 0 # -1: rows and extra_colors, 0: data, 1: pca, 2: scaled pca
182 self.color_index = 0 # column used for coloring with color_set_index
183 self.color_values = None # data column currently used for coloring as specified by color_set_index and color_index
184 self.color_label = None # label of data currently used for coloring
185 self.data_colors = None # actual colors for color_values
186 self.color_vmin = None
187 self.color_vmax = None
188 self.color_ticks = None
189 self.cbax = None # axes of color bar
190 # figure variables:
191 self.plt_params = {}
192 for k in ['toolbar', 'keymap.quit', 'keymap.back', 'keymap.forward',
193 'keymap.zoom', 'keymap.pan', 'keymap.xscale', 'keymap.yscale']:
194 self.plt_params[k] = plt.rcParams[k]
195 if k != 'toolbar':
196 plt.rcParams[k] = ''
197 self.xborder = 100.0 # pixel for ylabels
198 self.yborder = 50.0 # pixel for xlabels
199 self.spacing = 10.0 # pixel between plots
200 self.mborder = 20.0 # pixel around magnified plot
201 self.pick_radius = 4.0
202 # histogram plots:
203 self.hist_ax = [] # list of histogram axes
204 self.hist_indices = [] # feature index of the histogram axes
205 self.hist_selector = [] # for each histogram axes a selector
206 self.hist_nbins = 30 # number of bins for computing histograms
207 # scatter plots:
208 self.scatter_ax = [] # list of axes with scatter plots (1D)
209 self.scatter_indices = [] # for each axes a tuple of the column and row index
210 self.scatter_artists = [] # artists of selected scatter points
211 self.scatter_selector = [] # selector for each axes
212 self.scatter = True # scatter (True) or density (False)
213 self.mark_data = [] # list of indices of selected data
214 self.significance_level = 0.05 # r is bold if p is below
215 self.select_zooms = False
216 self.zoom_stack = []
217 # magnified scatter plot:
218 self.magnified_on = False
219 self.magnified_backdrop = None
220 self.magnified_size = np.array([0.6, 0.6])
221 # waveform plots:
222 self.wave_ax = []
223 # help window:
224 self.help_ax = None
227 def set_wave_data(self, data, xlabels='', ylabels=[], title=False):
228 """Add waveform data to explorer.
230 Parameters
231 ----------
232 data: list of (list of) 2D arrays
233 Waveform data associated with each row of the data.
234 Elements of the outer list correspond to the rows of the data.
235 The inner 2D arrays contain a common x-axes (first column)
236 and one or more corresponding y-values (second and optional higher columns).
237 Each column for y-values is plotted in its own axes on top of each other,
238 from top to bottom.
239 The optional inner list of 2D arrays contains several 2D arrays as ascribed above
240 each with its own common x-axes.
241 xlabel: str or list of str
242 The xlabels for the waveform plots. If only a string is given, then
243 there will be a common xaxis for all the plots, and only the lowest
244 one gets a labeled xaxis. If a list of strings is given, each waveform
245 plot gets its own labeled x-axis.
246 ylabels: list of str
247 The ylabels for each of the waveform plots.
248 title: bool or str
249 If True or a string, povide space on top of the waveform plots for a title.
250 If string, set this as the title for the waveform plots.
251 """
252 self.wave_data = []
253 if data is not None and len(data) > 0:
254 self.wave_data = data
255 self.wave_has_xticks = []
256 self.wave_nested = isinstance(data[0], (list, tuple))
257 if self.wave_nested:
258 for data in self.wave_data[0]:
259 for k in range(data.shape[1]-2):
260 self.wave_has_xticks.append(False)
261 self.wave_has_xticks.append(True)
262 else:
263 for k in range(self.wave_data[0].shape[1]-2):
264 self.wave_has_xticks.append(False)
265 self.wave_has_xticks.append(True)
266 if isinstance(xlabels, (list, tuple)):
267 self.wave_xlabels = xlabels
268 else:
269 self.wave_xlabels = [xlabels]
270 self.wave_ylabels = ylabels
271 self.wave_title = title
272 self.wave_ax = []
275 def set_colors(self, colors=0, color_label=None, color_map=None):
276 """Set data column used to color scatter plots.
278 Parameters
279 ----------
280 colors: int or 1D array
281 Index to colum in data to be used for coloring scatter plots.
282 -2 for coloring row index of data.
283 Or data array used to color scalar plots.
284 color_label: str
285 If colors is an array, this is a label describing the data.
286 It is used to label the color bar.
287 color_map: str or None
288 Name of a matplotlib color map.
289 If None 'jet' is used.
290 """
291 if isinstance(colors, (np.integer, int)):
292 if colors < 0:
293 self.color_set_index = -1
294 self.color_index = 0
295 else:
296 self.color_set_index = 0
297 self.color_index = colors
298 else:
299 if not isinstance(colors[0], (int, float,
300 np.integer, np.floating)):
301 # categorial data:
302 self.extra_categories, self.extra_colors = categorize(colors)
303 else:
304 self.extra_colors = colors
305 self.extra_colors = self.extra_colors[self.valid_samples]
306 self.extra_color_label = color_label
307 self.color_set_index = -1
308 self.color_index = 1
309 self.color_map = plt.get_cmap(color_map if color_map else 'jet')
312 def show(self, ioff=True):
313 """Show interactive scatter plots for exploration.
314 """
315 if ioff:
316 plt.ioff()
317 else:
318 plt.ion()
319 plt.rcParams['toolbar'] = 'None'
320 plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, ctrl+q, q'
321 plt.rcParams['font.size'] = 12
322 self.fig = plt.figure(facecolor='white', figsize=(10, 8))
323 self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
324 self.fig.canvas.mpl_connect('key_press_event', self._on_key)
325 self.fig.canvas.mpl_connect('resize_event', self._on_resize)
326 self.fig.canvas.mpl_connect('pick_event', self._on_pick)
327 if self.color_map is None:
328 self.color_map = plt.get_cmap('jet')
329 self._set_color_column()
330 self._init_hist_plots()
331 self._init_scatter_plots()
332 self.wave_ax = []
333 if self.wave_data is not None and len(self.wave_data) > 0:
334 axx = None
335 xi = 0
336 for k, has_xticks in enumerate(self.wave_has_xticks):
337 ax = self.fig.add_subplot(1, len(self.wave_has_xticks),
338 1+k, sharex=axx)
339 self.wave_ax.append(ax)
340 if has_xticks:
341 if xi >= len(self.wave_xlabels):
342 self.wave_xlabels.append('')
343 ax.set_xlabel(self.wave_xlabels[xi])
344 xi += 1
345 axx = None
346 else:
347 #ax.xaxis.set_major_formatter(plt.NullFormatter())
348 if axx is None:
349 axx = ax
350 for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
351 ax.set_ylabel(ylabel)
352 if not isinstance(self.wave_title, bool) and self.wave_title:
353 self.wave_ax[0].set_title(self.wave_title)
354 self.fix_waveform_plot(self.wave_ax, self.mark_data)
355 self._plot_magnified_scatter()
356 self._plot_help()
357 plt.show()
360 def _pca_header(self, data, labels):
361 """Set up header for the table of principal components.
363 Parameters
364 ----------
365 data: ndarray of float
366 The data (samples x features) without invalid (infinite or
367 NaN) numbers.
368 labels: list of str
369 Labels of the features.
370 """
371 lbs = []
372 for l, d in zip(labels, data):
373 if '[' in l:
374 lbs.append(l.split('[')[0].strip())
375 elif '/' in l:
376 lbs.append(l.split('/')[0].strip())
377 else:
378 lbs.append(l)
379 header = TableData(header=lbs)
380 header.set_formats('%.3f')
381 header.insert(0, ['PC'] + ['-']*header.nsecs, '', '%d')
382 header.insert(1, 'variance', '%', '%.3f')
383 for k in range(len(self.pca_tables)):
384 self.pca_tables[k] = TableData(header)
387 def compute_pca(self, scale=False, write=False):
388 """Compute PCA based on the data.
390 Parameters
391 ----------
392 scale: boolean
393 If True standardize data before computing PCA, i.e. remove mean
394 of each variabel and divide by its standard deviation.
395 write: boolean
396 If True write PCA components to standard out.
397 """
398 # pca:
399 pca = decomposition.PCA()
400 if scale:
401 scaler = preprocessing.StandardScaler()
402 scaler.fit(self.raw_data)
403 pca.fit(scaler.transform(self.raw_data))
404 pca_label = 'sPC'
405 else:
406 pca.fit(self.raw_data)
407 pca_label = 'PC'
408 for k in range(len(pca.components_)):
409 if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]):
410 pca.components_[k] *= -1.0
411 pca_data = pca.transform(self.raw_data)
412 pca_labels = [f'{pca_label}{k+1} ' + (f'({100*v:.1f}%)' if v > 0.01 else (f'{100*v:.2f}%'))
413 for k, v in enumerate(pca.explained_variance_ratio_)]
414 if np.min(pca.explained_variance_ratio_) >= 0.01:
415 pca_maxcols = pca_data.shape[1]
416 else:
417 pca_maxcols = np.argmax(pca.explained_variance_ratio_ < 0.01)
418 if pca_maxcols < 2:
419 pca_maxcols = 2
420 if pca_maxcols > 6:
421 pca_maxcols = 6
422 # table with PCA feature weights:
423 pca_table = self.pca_tables[1] if scale else self.pca_tables[0]
424 pca_table.clear_data()
425 pca_table.set_section(pca_label, 0, pca_table.nsecs)
426 for k, comp in enumerate(pca.components_):
427 pca_table.append_data(k+1, 0)
428 pca_table.append_data(100.0*pca.explained_variance_ratio_[k])
429 pca_table.append_data(comp)
430 if write:
431 pca_table.write(table_format='out', unit_style='none')
432 # submit data:
433 if scale:
434 self.all_data[2] = pca_data
435 self.all_labels[2] = pca_labels
436 self.all_maxcols[2] = pca_maxcols
437 else:
438 self.all_data[1] = pca_data
439 self.all_labels[1] = pca_labels
440 self.all_maxcols[1] = pca_maxcols
443 def save_pca(self, file_name, scale, **kwargs):
444 """Write PCA data to file.
446 Parameters
447 ----------
448 file_name: str
449 Name of ouput file.
450 scale: boolean
451 If True write PCA components of standardized PCA.
452 kwargs: dict
453 Additional parameter for TableData.write()
454 """
455 if scale:
456 pca_file = file_name + '-pcacor'
457 pca_table = self.pca_tables[1]
458 else:
459 pca_file = file_name + '-pcacov'
460 pca_table = self.pca_tables[0]
461 if 'unit_style' in kwargs:
462 del kwargs['unit_style']
463 if 'table_format' in kwargs:
464 pca_table.write(pca_file, unit_style='none', **kwargs)
465 else:
466 pca_file += '.dat'
467 pca_table.write(pca_file, unit_style='none')
470 def _set_color_column(self):
471 """Initialize variables used for colorization of scatter points."""
472 if self.color_set_index == -1:
473 if self.color_index == 0:
474 self.color_values = np.arange(self.data.shape[0], dtype=float)
475 self.color_label = 'sample'
476 elif self.color_index == 1:
477 self.color_values = self.extra_colors
478 self.color_label = self.extra_color_label
479 else:
480 self.color_values = self.all_data[self.color_set_index][:,self.color_index]
481 self.color_label = self.all_labels[self.color_set_index][self.color_index]
482 self.color_vmin, self.color_vmax, self.color_ticks = \
483 self.fix_scatter_plot(self.cbax, self.color_values,
484 self.color_label, 'c')
485 if self.color_ticks is None:
486 if self.color_set_index == 0 and \
487 self.categories[self.color_index] is not None:
488 self.color_ticks = np.arange(len(self.categories[self.color_index]))
489 elif self.color_set_index == -1 and \
490 self.color_index == 1 and \
491 self.extra_categories is not None:
492 self.color_ticks = np.arange(len(self.extra_categories))
493 self.data_colors = self.color_map((self.color_values - self.color_vmin)/(self.color_vmax - self.color_vmin))
496 def _add_backdrop(self, ax):
497 bbox = ax.get_tightbbox(self.fig.canvas.get_renderer())
498 if bbox is not None:
499 self.magnified_backdrop = \
500 patches.Rectangle((bbox.x0 - self.mborder,
501 bbox.y0 - self.mborder),
502 bbox.width + 2*self.mborder,
503 bbox.height + 2*self.mborder,
504 transform=None, clip_on=False,
505 facecolor='#f7f7f7', edgecolor='none',
506 zorder=-5)
507 ax.add_patch(self.magnified_backdrop)
510 def _create_selector(self, ax):
511 try:
512 selector = \
513 widgets.RectangleSelector(ax, self._on_select,
514 useblit=True, button=1,
515 minspanx=0, minspany=0,
516 spancoords='pixels',
517 props=dict(facecolor='gray',
518 edgecolor='gray',
519 alpha=0.2,
520 fill=True),
521 state_modifier_keys=dict(move='',
522 clear='',
523 square='',
524 center='ctrl'))
525 except TypeError:
526 # old matplotlib:
527 selector = widgets.RectangleSelector(ax, self._on_select,
528 useblit=True, button=1)
529 return selector
    def _plot_hist(self, ax, magnifiedax):
        """Plot and label a histogram.

        The feature to be plotted is determined from `ax`: either `ax` is
        one of `self.hist_ax` (feature given by `self.hist_indices`), or
        it is one of `self.scatter_ax`, in which case the first index of
        its entry in `self.scatter_indices` selects the feature.
        A new rectangle selector is created for `ax`.

        Parameters
        ----------
        ax: matplotlib axes
            Axes to draw the histogram into. Its content is cleared.
        magnifiedax: bool
            True if `ax` is the magnified plot. Then the y-axis is labeled,
            the x-range is copied from the histogram of the same feature,
            and a backdrop is added.
        """
        try:
            idx = self.hist_ax.index(ax)
            c = self.hist_indices[idx]
            in_hist = True
        except ValueError:
            # not a histogram axes but a scatter axes
            # (presumably the magnified plot showing a histogram):
            idx = self.scatter_ax.index(ax)
            c = self.scatter_indices[idx][0]
            in_hist = False
        ax.clear()
        #ax.relim()
        #ax.autoscale(True)
        x = self.data[:,c]
        ax.hist(x, self.hist_nbins)
        #ax.autoscale(False)
        ax.set_xlabel(self.labels[c])
        ax.xaxis.set_major_locator(plt.AutoLocator())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        if self.show_mode == 0:
            # raw data (not PCA): label categorical features by category names
            if self.categories[c] is not None:
                ax.set_xticks(np.arange(len(self.categories[c])))
                ax.set_xticklabels(self.categories[c])
            self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
        if magnifiedax:
            # number of samples:
            ax.text(0.05, 0.9, f'n={len(self.data)}',
                    transform=ax.transAxes, zorder=100)
            ax.set_ylabel('count')
            # the magnified plot is the last scatter axes; copy the x-range
            # from the histogram axes of the same feature:
            cax = self.hist_ax[self.scatter_indices[-1][0]]
            ax.set_xlim(cax.get_xlim())
        else:
            # only the first (leftmost) histogram gets the number of samples
            # and a labeled y-axis:
            if c == 0:
                ax.text(0.05, 0.9, f'n={len(self.data)}',
                        transform=ax.transAxes, zorder=100)
                ax.set_ylabel('count')
            else:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
        selector = self._create_selector(ax)
        if in_hist:
            self.hist_selector[idx] = selector
        else:
            self.scatter_selector[idx] = selector
            self.scatter_artists[idx] = None
        ax.relim(True)
        if magnifiedax:
            self._add_backdrop(ax)
580 def _set_hist_ylim(self):
581 ymax = np.max([ax.get_ylim() for ax in self.hist_ax[:self.maxcols]], 0)[1]
582 for ax in self.hist_ax:
583 ax.set_ylim(0, ymax)
586 def _init_hist_plots(self):
587 """Initial plots of the histograms."""
588 n = self.data.shape[1]
589 self.hist_ax = []
590 for r in range(n):
591 ax = self.fig.add_subplot(n, n, (n-1)*n+r+1)
592 self.hist_ax.append(ax)
593 self.hist_indices.append(r)
594 self.hist_selector.append(None)
595 self._plot_hist(ax, False)
596 self._set_hist_ylim()
    def _plot_scatter(self, ax, magnifiedax, cax=None):
        """Plot a scatter plot.

        The pair of features is taken from the entry of `ax` in
        `self.scatter_indices` (column index c on the x-axis, row index r
        on the y-axis). If `self.scatter` is True, all data points are
        drawn colored by `self.data_colors`, together with Pearson's
        correlation coefficient and its p-value. Otherwise a gray-scale
        2D histogram is drawn. In both cases the selected data points
        (`self.mark_data`) are drawn on top, and a new rectangle selector
        is created for `ax`.

        Parameters
        ----------
        ax: matplotlib axes
            One of the axes in `self.scatter_ax`. Its content is cleared.
        magnifiedax: bool
            True if `ax` is the magnified plot. Then both axes are labeled,
            limits are copied from the corresponding scatter axes,
            and a backdrop is added.
        cax: matplotlib axes or None
            If not None, axes into which the color bar is drawn
            (only in scatter mode).
        """
        idx = self.scatter_ax.index(ax)
        c, r = self.scatter_indices[idx]
        if self.scatter: # scatter plot
            ax.clear()
            a = ax.scatter(self.data[:,c], self.data[:,r], s=50,
                           edgecolors='white', linewidths=0.5,
                           picker=self.pick_radius, zorder=10)
            a.set_facecolor(self.data_colors)
            pr, pp = pearsonr(self.data[:,c], self.data[:,r])
            # bold if p is below the significance level divided by
            # the number of features:
            fw = 'bold' if pp < self.significance_level/self.data.shape[1] else 'normal'
            # negative correlation: text in upper right, otherwise upper left corner
            if pr < 0:
                ax.text(0.95, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                        transform=ax.transAxes, ha='right', zorder=100)
            else:
                ax.text(0.05, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                        transform=ax.transAxes, zorder=100)
            # color bar:
            if cax is not None:
                # temporary color-mapped artist, only needed to create the color bar:
                a = ax.scatter(self.data[:, c], self.data[:, r],
                               c=self.color_values, cmap=self.color_map)
                self.fig.colorbar(a, cax=cax, ticks=self.color_ticks)
                a.remove()
                cax.set_ylabel(self.color_label)
                self.color_vmin, self.color_vmax, self.color_ticks = \
                    self.fix_scatter_plot(self.cbax, self.color_values,
                                          self.color_label, 'c')
                if self.color_ticks is None:
                    # categorical colors: label ticks with category names
                    if self.color_set_index == 0 and \
                       self.categories[self.color_index] is not None:
                        cax.set_yticklabels(self.categories[self.color_index])
                    elif self.color_set_index == -1 and \
                         self.color_index == 1 and \
                         self.extra_categories is not None:
                        cax.set_yticklabels(self.extra_categories)
        else: # histogram
            if self.show_mode == 0:
                self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
                self.fix_scatter_plot(ax, self.data[:,r], self.labels[r], 'y')
            # use the current axes limits as range of the 2D histogram:
            axrange = [ax.get_xlim(), ax.get_ylim()]
            ax.clear()
            ax.hist2d(self.data[:,c], self.data[:,r], self.hist_nbins,
                      range=axrange, cmap=plt.get_cmap('Greys'))
        # selected data:
        a = ax.scatter(self.data[self.mark_data, c],
                       self.data[self.mark_data, r], s=100,
                       edgecolors='black', linewidths=0.5,
                       picker=self.pick_radius, zorder=11)
        a.set_facecolor(self.data_colors[self.mark_data])
        self.scatter_artists[idx] = a
        ax.xaxis.set_major_locator(plt.AutoLocator())
        ax.yaxis.set_major_locator(plt.AutoLocator())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        ax.yaxis.set_major_formatter(plt.ScalarFormatter())
        if self.show_mode == 0:
            # raw data (not PCA): label categorical features by category names
            if self.categories[c] is not None:
                ax.set_xticks(np.arange(len(self.categories[c])))
                ax.set_xticklabels(self.categories[c])
            if self.categories[r] is not None:
                ax.set_yticks(np.arange(len(self.categories[r])))
                ax.set_yticklabels(self.categories[r])
        if magnifiedax:
            ax.set_xlabel(self.labels[c])
            ax.set_ylabel(self.labels[r])
            # copy limits from the scatter axes showing the same pair of
            # features (the magnified plot is the last entry):
            cax = self.scatter_ax[self.scatter_indices[:-1].index(self.scatter_indices[-1])]
            ax.set_xlim(cax.get_xlim())
            ax.set_ylim(cax.get_ylim())
        else:
            # only the first column of scatter plots gets y-labels:
            if c == 0:
                ax.set_ylabel(self.labels[r])
        if self.show_mode == 0:
            self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
            self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
        if not magnifiedax:
            # tick labels are only shown on the histograms and the first column:
            ax.xaxis.set_major_formatter(plt.NullFormatter())
            if c > 0:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
            # align limits with the x-ranges of the histograms of both features:
            ax.set_xlim(*self.hist_ax[c].get_xlim())
            ax.set_ylim(*self.hist_ax[r].get_xlim())
        if magnifiedax:
            self._add_backdrop(ax)
        selector = self._create_selector(ax)
        self.scatter_selector[idx] = selector
        ax.relim(True)
686 def _init_scatter_plots(self):
687 """Initial plots of scatter plots."""
688 self.cbax = self.fig.add_axes([0.5, 0.5, 0.1, 0.5])
689 cbax = self.cbax
690 n = self.data.shape[1]
691 for r in range(1, n):
692 for c in range(r):
693 ax = self.fig.add_subplot(n, n, (r-1)*n+c+1)
694 self.scatter_ax.append(ax)
695 self.scatter_indices.append([c, r])
696 self.scatter_artists.append(None)
697 self.scatter_selector.append(None)
698 self._plot_scatter(ax, False, cbax)
699 cbax = None
702 def _plot_magnified_scatter(self):
703 """Initial plot of the magnified scatter plot."""
704 ax = self.fig.add_axes([0.5, 0.9, 0.05, 0.05])
705 ax.set_visible(False)
706 self.magnified_on = False
707 c = 0
708 r = 1
709 a = ax.scatter(self.data[:, c], self.data[:, r],
710 s=50, edgecolors='none')
711 a.set_facecolor(self.data_colors)
712 a = ax.scatter(self.data[self.mark_data, c],
713 self.data[self.mark_data, r], s=80)
714 a.set_facecolor(self.data_colors[self.mark_data])
715 ax.set_xlabel(self.labels[c])
716 ax.set_ylabel(self.labels[r])
717 self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
718 self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
719 self.scatter_ax.append(ax)
720 self.scatter_indices.append([c, r])
721 self.scatter_artists.append(a)
722 self.scatter_selector.append(None)
725 def _plot_help(self):
726 ax = self.fig.add_subplot(1, 1, 1)
727 ax.set_position([0.02, 0.02, 0.96, 0.96])
728 ax.xaxis.set_major_locator(plt.NullLocator())
729 ax.yaxis.set_major_locator(plt.NullLocator())
730 n = len(self.mouse_actions) + len(self.key_actions) + 4
731 dy = 1/n
732 y = 1 - 1.5*dy
733 ax.text(0.05, y, 'Key shortcuts', transform=ax.transAxes,
734 fontweight='bold')
735 y -= dy
736 for a, d in self.key_actions:
737 ax.text(0.05, y, a, transform=ax.transAxes)
738 ax.text(0.3, y, d, transform=ax.transAxes)
739 y -= dy
740 y -= dy
741 ax.text(0.05, y, 'Mouse actions', transform=ax.transAxes,
742 fontweight='bold')
743 y -= dy
744 for a, d in self.mouse_actions:
745 ax.text(0.05, y, a, transform=ax.transAxes)
746 ax.text(0.3, y, d, transform=ax.transAxes)
747 y -= dy
748 ax.set_visible(False)
749 self.help_ax = ax
752 def fix_scatter_plot(self, ax, data, label, axis):
753 """Customize an axes of a scatter plot.
755 This function is called after a scatter plot has been plotted.
756 Once for the x axes, once for the y axis and once for the color bar.
757 Reimplement this function to set appropriate limits and ticks.
759 Return values are only used for the color bar (`axis='c'`).
760 Otherwise they are ignored.
762 For example, ticks for phase variables can be nicely labeled
763 using the unicode character for pi:
764 ```
765 if 'phase' in label:
766 if axis == 'y':
767 ax.set_ylim(0.0, 2.0*np.pi)
768 ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi))
769 ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0'])
770 ```
772 Parameters
773 ----------
774 ax: matplotlib axes
775 Axes of the scatter plot or color bar to be worked on.
776 data: 1D array
777 Data array of the axes.
778 label: str
779 Label coresponding to the data array.
780 axis: str
781 'x', 'y': set properties of x or y axes of ax.
782 'c': set properies of color bar axes (note that ax can be None!)
783 and return vmin, vmax, and ticks.
785 Returns
786 -------
787 min: float
788 minimum value of color bar axis
789 max: float
790 maximum value of color bar axis
791 ticks: list of float
792 position of ticks for color bar axis
793 """
794 return np.nanmin(data), np.nanmax(data), None
797 def fix_waveform_plot(self, axs, indices):
798 """Customize waveform plots.
800 This function is called once after new data have been plotted
801 into the waveform plots. Reimplement this function to customize
802 these plots. In particular to set axis limits and labels, plot
803 title, etc.
804 You may even open a new figure (with non-blocking `show()`).
806 The following member variables might be usefull:
807 - `self.wave_data`: the full list of waveform data.
808 - `self.wave_nested`: True if the elements of `self.wave_data` are lists of 2D arrays. Otherwise the elements are 2D arrays. The first column of a 2D array contains the x-values, further columns y-values.
809 - `self.wave_has_xticks`: List of booleans for each axis. True if the axis has its own xticks.
810 - `self.wave_xlabels`: List of xlabels (only for the axis where the corresponding entry in `self.wave_has_xticks` is True).
811 - `self.wave_ylabels`: for each axis its ylabel
813 For example, you can set the linewidth of all plotted waveforms via:
814 ```
815 for ax in axs:
816 for l in ax.lines:
817 l.set_linewidth(3.0)
818 ```
819 or enable markers to be plotted:
820 ```
821 for ax, yl in zip(axs, self.wave_ylabels):
822 if 'Power' in yl:
823 for l in ax.lines:
824 l.set_marker('.')
825 l.set_markersize(15.0)
826 l.set_markeredgewidth(0.5)
827 l.set_markeredgecolor('k')
828 l.set_markerfacecolor(l.get_color())
829 ```
830 Usefull is to reduce the maximum number of y-ticks:
831 ```
832 axs[0].yaxis.get_major_locator().set_params(nbins=7)
833 ```
834 or
835 ```
836 import matplotlib.ticker as ticker
837 axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
838 ```
840 Parameters
841 ----------
842 axs: list of matplotlib axes
843 Axis of the waveform plots to be worked on.
844 indices: list of int
845 Indices of the waveforms that have been selected and plotted.
846 """
847 pass
850 def list_selection(self, indices):
851 """List information about the current selection of data points.
853 This function is called when 'l' is pressed. Reimplement this
854 function, for example, to print some meaningfull information
855 about the current selection of data points on console. You may
856 do, however, whatever you want in this function.
858 Parameters
859 ----------
860 indices: list of int
861 Indices of the data points that have been selected.
862 """
863 print('')
864 print('selected rows in data table:')
865 for i in indices:
866 print(self.valid_rows[i])
869 def analyze_selection(self, index):
870 """Provide further information about a single selected data point.
872 This function is called when a single data item was double
873 clicked. Reimplement this function to provide some further
874 details on this data point. This can be an additional figure
875 window. In this case show it non-blocking:
876 `plt.show(block=False)`
878 Parameters
879 ----------
880 index: int
881 The index of the selected data point.
882 """
883 pass
def _set_magnified_pos(self, width, height):
    """Place the magnified plot, or park and hide it.

    If the magnified plot is switched on, it is centered on the scatter
    plot or histogram it magnifies. It is then shifted so that it stays
    within the figure borders, and made visible. Otherwise it is moved
    out of the way and hidden.

    Parameters
    ----------
    width: float
        Width of the figure, in the same units as `self.xborder` and
        `self.spacing` (presumably pixels).
    height: float
        Height of the figure, in the same units as `self.yborder` and
        `self.spacing`.
    """
    magnified = self.scatter_ax[-1]
    if not self.magnified_on:
        magnified.set_position([0.5, 0.9, 0.05, 0.05])
        magnified.set_visible(False)
        return
    col, row = self.scatter_indices[-1]
    if row < self.data.shape[1]:
        # the small scatter plot with the same column/row pair:
        k = self.scatter_indices[:-1].index(self.scatter_indices[-1])
        source = self.scatter_ax[k]
    else:
        # a row index beyond the data columns denotes a histogram:
        source = self.hist_ax[col]
    size = self.magnified_size
    lower_x = self.xborder/width
    lower_y = self.yborder/height
    upper_x = 1.0 - self.spacing/width
    upper_y = 1.0 - self.spacing/height
    pos = source.get_position().get_points()
    # center on the source plot, then push right/up into the borders:
    pos[0] = np.mean(pos, 0) - 0.5*size
    pos[0][0] = max(pos[0][0], lower_x)
    pos[0][1] = max(pos[0][1], lower_y)
    # then push left/down so the upper-right corner stays inside:
    pos[1] = pos[0] + size
    pos[1][0] = min(pos[1][0], upper_x)
    pos[1][1] = min(pos[1][1], upper_y)
    pos[0] = pos[1] - size
    magnified.set_position([pos[0][0], pos[0][1], size[0], size[1]])
    magnified.set_visible(True)
def _make_selection(self, ax, key, x0, x1, y0, y1):
    """Select data points from a scatter plot or a histogram.

    Data points falling into the rectangle `x0`-`x1`, `y0`-`y1` (in
    data coordinates of `ax`) are added to, or removed from, the list
    of selected data points `self.mark_data`.

    Parameters
    ----------
    ax: matplotlib axes
        A scatter plot, a histogram, or the magnified plot in which the
        selection was made.
    key: str or None
        Modifier key. Without 'shift' or 'control' the previous
        selection is discarded. With 'shift' the points are added to
        the current selection. With 'control' they are removed from it.
    x0, x1: float
        Range of selected x-values.
    y0, y1: float
        Range of selected y-values. Ignored for histograms.
    """
    if key not in ('shift', 'control'):
        self.mark_data = []
    if ax in self.scatter_ax:
        c, r = self.scatter_indices[self.scatter_ax.index(ax)]
        xdata = self.data[:, c]
        inside = (xdata >= x0) & (xdata <= x1)
        if r < self.data.shape[1]:
            # scatter plot of column c (x-axis) versus column r (y-axis):
            ydata = self.data[:, r]
            inside &= (ydata >= y0) & (ydata <= y1)
        # otherwise the magnified plot shows the histogram of column c
    elif ax in self.hist_ax:
        xdata = self.data[:, self.hist_indices[self.hist_ax.index(ax)]]
        inside = (xdata >= x0) & (xdata <= x1)
    else:
        return
    hits = [int(i) for i in np.flatnonzero(inside)]
    # mark_data may be an index array (as set by a pick event),
    # so work on a list copy; sets make membership tests O(1):
    marked = [int(i) for i in self.mark_data]
    if key == 'control':
        drop = set(hits)
        self.mark_data = [i for i in marked if i not in drop]
    else:
        present = set(marked)
        self.mark_data = marked + [i for i in hits if i not in present]
def _update_selection(self):
    """Highlight selected points in the scatter plots and plot corresponding waveforms.

    Moves the highlight artists of all scatter plots onto the data
    points listed in `self.mark_data`. If waveform axes exist, it
    replots them with one line per selected data point, relabels them,
    and passes them to `fix_waveform_plot()`. Finally the canvas is
    redrawn.
    """
    # update scatter plots:
    for artist, (c, r) in zip(self.scatter_artists, self.scatter_indices):
        if artist is not None:
            if len(self.mark_data) == 0:
                # nothing selected: remove all highlighted points
                artist.set_offsets(np.zeros((0, 2)))
            else:
                artist.set_offsets(list(zip(self.data[self.mark_data, c],
                                            self.data[self.mark_data, r])))
            artist.set_facecolors(self.data_colors[self.mark_data])
    # waveform plots:
    if len(self.wave_ax) > 0:
        axdi = 0  # index of the waveform data set (nested waveforms) and of its x-label
        axti = 1  # column of that data set shown in the current axes (column 0 is x)
        for xi, ax in enumerate(self.wave_ax):
            ax.clear()
            if len(self.mark_data) > 0:
                for idx in self.mark_data:
                    if self.wave_nested:
                        data = self.wave_data[idx][axdi]
                    else:
                        data = self.wave_data[idx]
                    if data is not None:
                        # pickable lines, colored like the data point:
                        ax.plot(data[:, 0], data[:, axti],
                                c=self.data_colors[idx],
                                picker=self.pick_radius)
            axti += 1
            if self.wave_has_xticks[xi]:
                # axes with x-ticks close a group; the next axes start
                # again at column 1 of the next waveform data set:
                ax.set_xlabel(self.wave_xlabels[axdi])
                axti = 1
                axdi += 1
            #else:
            #    ax.xaxis.set_major_formatter(plt.NullFormatter())
        for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
            ax.set_ylabel(ylabel)
        # a boolean wave_title only controls layout (see _set_layout):
        if not isinstance(self.wave_title, bool) and self.wave_title:
            self.wave_ax[0].set_title(self.wave_title)
        self.fix_waveform_plot(self.wave_ax, self.mark_data)
    self.fig.canvas.draw()
def _set_limits(self, ax, x0, x1, y0, y1):
    """Set the axis limits of a plot and of all plots sharing its columns.

    Used for zooming into a plot and for restoring limits from the
    zoom stack.

    Parameters
    ----------
    ax: matplotlib axes
        A histogram, a scatter plot, or the magnified plot.
    x0, x1: float
        New limits of the x-axis.
    y0, y1: float
        New limits of the y-axis. For histograms these are counts, and
        a zoom into a small histogram applies them to all histograms.
    """
    if ax in self.hist_ax:
        ax.set_xlim(x0, x1)
        for hax in self.hist_ax:
            hax.set_ylim(y0, y1)
        cc = self.hist_indices[self.hist_ax.index(ax)]
        for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices):
            if c == cc:
                sax.set_xlim(x0, x1)
            if r == cc:
                sax.set_ylim(x0, x1)
    if ax in self.scatter_ax:
        idx = self.scatter_ax.index(ax)
        cc, rr = self.scatter_indices[idx]
        if rr >= self.data.shape[1]:
            # magnified histogram of column cc: its y-axis shows counts
            # and there is no data column (nor histogram) with index rr.
            ax.set_ylim(y0, y1)
            self.hist_ax[cc].set_xlim(x0, x1)
            for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices):
                if c == cc:
                    sax.set_xlim(x0, x1)
                if r == cc:
                    sax.set_ylim(x0, x1)
            return
        # scatter plot of column cc (x-axis) versus column rr (y-axis):
        self.hist_ax[cc].set_xlim(x0, x1)
        self.hist_ax[rr].set_xlim(y0, y1)
        for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices):
            if c == cc:
                sax.set_xlim(x0, x1)
            if c == rr:
                sax.set_xlim(y0, y1)
            if r == cc:
                sax.set_ylim(x0, x1)
            if r == rr:
                sax.set_ylim(y0, y1)
def _on_key(self, event):
    """Handle key events.

    Arrow keys move the magnified plot to a neighboring scatter plot or
    histogram (see `_move_magnified()`). The other supported keys:
    - hide the magnified plot;
    - toggle zooming, undo zooms;
    - change the pick radius or the number of visible columns;
    - toggle waveforms, select all data points;
    - cycle color coding, change histogram bins, toggle scatter mode;
    - switch between data and principal components;
    - list the selection or toggle the help.

    Parameters
    ----------
    event: matplotlib KeyEvent
        The key press event. Events without a key (matplotlib may
        report `key=None`) are ignored.
    """
    key = event.key
    if not key:
        # `None in 'oz'` below would raise a TypeError:
        return
    if key in ['left', 'right', 'up', 'down']:
        if self.magnified_on:
            self._move_magnified(key)
    elif key == 'escape':
        if len(self.scatter_ax) >= self.data.shape[1]:
            self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
        self.magnified_on = False
        self.scatter_ax[-1].set_visible(False)
        self.fig.canvas.draw()
    elif key in 'oz':
        self.select_zooms = not self.select_zooms
    elif key == 'backspace':
        if len(self.zoom_stack) > 0:
            self._set_limits(*self.zoom_stack.pop())
            self.fig.canvas.draw()
    elif key in '+=':
        self.pick_radius *= 1.5
    elif key in '-':
        if self.pick_radius > 5.0:
            self.pick_radius /= 1.5
    elif key in '0':
        self.pick_radius = 4.0
    elif key in ['pageup', 'pagedown', '<', '>']:
        if key in ['pageup', '<'] and self.maxcols > 2:
            self.maxcols -= 1
        elif key in ['pagedown', '>'] and self.maxcols < self.raw_data.shape[1]:
            self.maxcols += 1
        for ax in self.hist_ax:
            self._plot_hist(ax, False)
        self._update_layout()
    elif key == 'w':
        if len(self.wave_data) > 0:
            # toggle between waveform-only view (maxcols == 0) and the grid:
            if self.maxcols > 0:
                self.all_maxcols[self.show_mode] = self.maxcols
                self.maxcols = 0
            else:
                self.maxcols = self.all_maxcols[self.show_mode]
            self._set_layout(self.fig.get_window_extent().width,
                             self.fig.get_window_extent().height)
            self.fig.canvas.draw()
    elif key == 'ctrl+a':
        self.mark_data = list(range(len(self.data)))
        self._update_selection()
    elif key in 'cC':
        # color_set_index -1 denotes row index/extra colors,
        # >= 0 an entry of all_data:
        if key in 'c':
            self.color_index -= 1
            if self.color_index < 0:
                self.color_set_index -= 1
                if self.color_set_index < -1:
                    self.color_set_index = len(self.all_data)-1
                if self.color_set_index >= 0:
                    if self.all_data[self.color_set_index] is None:
                        self.compute_pca(self.color_set_index>1, True)
                    self.color_index = self.all_data[self.color_set_index].shape[1]-1
                else:
                    self.color_index = 0 if self.extra_colors is None else 1
            self._set_color_column()
        else:
            self.color_index += 1
            if (self.color_set_index >= 0 and \
                self.color_index >= self.all_data[self.color_set_index].shape[1]) or \
                (self.color_set_index < 0 and \
                 self.color_index >= (1 if self.extra_colors is None else 2)):
                self.color_index = 0
                self.color_set_index += 1
                if self.color_set_index >= len(self.all_data):
                    self.color_set_index = -1
                elif self.all_data[self.color_set_index] is None:
                    self.compute_pca(self.color_set_index>1, True)
            self._set_color_column()
        for ax in self.scatter_ax:
            ax.collections[0].set_facecolors(self.data_colors)
        for a in self.scatter_artists:
            if a is not None:
                a.set_facecolors(self.data_colors[self.mark_data])
        for ax in self.wave_ax:
            for l, c in zip(ax.lines, self.data_colors[self.mark_data]):
                l.set_color(c)
                l.set_markerfacecolor(c)
        self._plot_scatter(self.scatter_ax[0], False, self.cbax)
        self.fix_scatter_plot(self.cbax, self.color_values,
                              self.color_label, 'c')
        self.fig.canvas.draw()
    elif key in 'nN':
        if key in 'N':
            self.hist_nbins = (self.hist_nbins*3)//2
        elif self.hist_nbins >= 15:
            self.hist_nbins = (self.hist_nbins*2)//3
        else:
            self.hist_nbins = 10
        for ax in self.hist_ax:
            self._plot_hist(ax, False)
        self._set_hist_ylim()
        if self.scatter_indices[-1][1] >= self.data.shape[1]:
            self._plot_hist(self.scatter_ax[-1], True, True)
        elif not self.scatter:
            self._plot_scatter(self.scatter_ax[-1], True)
        if not self.scatter:
            for ax in self.scatter_ax[:-1]:
                self._plot_scatter(ax, False)
        self.fig.canvas.draw()
    elif key in 'H':
        self.scatter = not self.scatter
        for ax in self.scatter_ax[:-1]:
            self._plot_scatter(ax, False)
        if self.scatter_indices[-1][1] < self.data.shape[1]:
            self._plot_scatter(self.scatter_ax[-1], True)
        self.fig.canvas.draw()
    elif key in 'pP':
        if len(self.scatter_ax) >= self.data.shape[1]:
            self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
        self.scatter_indices[-1] = [0, 1]
        self.magnified_on = False
        self.scatter_ax[-1].set_visible(False)
        self.all_maxcols[self.show_mode] = self.maxcols
        if key == 'P':
            self.show_mode += 1
            if self.show_mode >= len(self.all_data):
                self.show_mode = 0
        else:
            self.show_mode -= 1
            if self.show_mode < 0:
                self.show_mode = len(self.all_data)-1
        if self.show_mode == 1:
            print('principal components')
        elif self.show_mode == 2:
            print('scaled principal components')
        else:
            print('data')
        if self.all_data[self.show_mode] is None:
            self.compute_pca(self.show_mode>1, True)
        self.data = self.all_data[self.show_mode]
        self.labels = self.all_labels[self.show_mode]
        self.maxcols = self.all_maxcols[self.show_mode]
        self.zoom_stack = []
        self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
        for ax in self.hist_ax:
            self._plot_hist(ax, False)
        self._set_hist_ylim()
        for ax in self.scatter_ax:
            self._plot_scatter(ax, False)
        self._update_layout()
    elif key in 'l':
        if len(self.mark_data) > 0:
            self.list_selection(self.mark_data)
    elif key in 'h':
        self.help_ax.set_visible(not self.help_ax.get_visible())
        self.fig.canvas.draw()

def _move_magnified(self, key):
    """Move the magnified plot one step in direction `key` and redraw it.

    The magnified plot shows the scatter plot of columns
    `[c, r] = self.scatter_indices[-1]`, or, if `r` equals the number
    of data columns, the histogram of column `c`. Only plots in the
    visible grid (indices below `self.maxcols`) are visited. Moving
    past the last plot wraps around.

    Parameters
    ----------
    key: str
        One of 'left', 'right', 'up', or 'down'.
    """
    nhist = self.data.shape[1]  # row index denoting a histogram
    indices = self.scatter_indices[-1]
    mc, mr = indices
    if key == 'left':
        if mc > 0:
            indices[0] -= 1
        elif mr > 1:
            if mr >= nhist:
                indices[1] = self.maxcols - 1
            else:
                indices[1] -= 1
            indices[0] = indices[1] - 1
        else:
            indices[0] = self.maxcols - 1
            indices[1] = nhist
    elif key == 'right':
        if mc < mr - 1 and mc < self.maxcols - 1:
            indices[0] += 1
        elif mr < self.maxcols:
            indices[0] = 0
            indices[1] += 1
            if indices[1] >= self.maxcols:
                indices[1] = nhist
        else:
            indices[0] = 0
            indices[1] = 1
    elif key == 'up':
        # histograms are displayed in the row below the last scatter row:
        row = self.maxcols if mr >= nhist else mr
        if row > mc + 1:
            indices[1] = row - 1
        elif mc > 0:
            indices[0] -= 1
            indices[1] = nhist
        else:
            indices[0] = self.maxcols - 1
            indices[1] = nhist
    elif key == 'down':
        if mr < self.maxcols:
            indices[1] += 1
            if indices[1] >= self.maxcols:
                indices[1] = nhist
        elif mc < self.maxcols - 1:
            indices[0] += 1
            indices[1] = mc + 2
            if indices[1] >= self.maxcols:
                indices[1] = nhist
        else:
            indices[0] = 0
            indices[1] = 1
    magnified = self.scatter_ax[-1]
    # stored zooms of the magnified plot refer to the previous plot:
    self.zoom_stack = [z for z in self.zoom_stack if z[0] is not magnified]
    magnified.clear()
    magnified.set_visible(True)
    self.magnified_on = True
    extent = self.fig.get_window_extent()
    self._set_magnified_pos(extent.width, extent.height)
    if indices[1] < nhist:
        self._plot_scatter(magnified, True)
    else:
        self._plot_hist(magnified, True)
    self.fig.canvas.draw()
def _on_select(self, eclick, erelease):
    """Handle rectangle selection events.

    A double click passes the most recently selected data point to
    `analyze_selection()`. A rectangle smaller than 2% of the axis
    ranges counts as a click. Its selection area is then a box of
    `self.pick_radius` pixels around the release position. Larger
    rectangles are used as they are and, if zooming is enabled, also
    zoom into the plot (previous limits are pushed onto the zoom
    stack). Finally the data points in the area are selected.

    Parameters
    ----------
    eclick: matplotlib MouseEvent
        The press event.
    erelease: matplotlib MouseEvent
        The release event.
    """
    if eclick.dblclick:
        if len(self.mark_data) > 0:
            self.analyze_selection(self.mark_data[-1])
        return
    xs = (eclick.xdata, erelease.xdata)
    ys = (eclick.ydata, erelease.ydata)
    x0, x1 = min(xs), max(xs)
    y0, y1 = min(ys), max(ys)
    ax = erelease.inaxes if erelease.inaxes is not None else eclick.inaxes
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    xspan = xmax - xmin
    yspan = ymax - ymin
    is_click = x1 - x0 < 0.02*xspan and y1 - y0 < 0.02*yspan
    if is_click:
        # convert the pick radius from pixels into data units:
        bbox = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted())
        width_px = bbox.width*self.fig.dpi
        height_px = bbox.height*self.fig.dpi
        dx = self.pick_radius*xspan/width_px
        dy = self.pick_radius*yspan/height_px
        x0, x1 = erelease.xdata - dx, erelease.xdata + dx
        y0, y1 = erelease.ydata - dy, erelease.ydata + dy
    elif self.select_zooms:
        self.zoom_stack.append((ax, xmin, xmax, ymin, ymax))
        self._set_limits(ax, x0, x1, y0, y1)
    self._make_selection(ax, erelease.key, x0, x1, y0, y1)
    self._update_selection()
def _on_pick(self, event):
    """Handle pick events.

    Picking a waveform line selects the data point it belongs to.
    Picking points in a scatter plot selects all picked points. The
    selection is then highlighted. On a double click the last selected
    data point is additionally passed to `analyze_selection()`.

    Parameters
    ----------
    event: matplotlib PickEvent
        The pick event with the picked `artist`, the indices `ind` of
        picked scatter points, and the underlying `mouseevent`.
    """
    for ax in self.wave_ax:
        for k, l in enumerate(ax.lines):
            if l is event.artist:
                # NOTE(review): assumes the k-th line belongs to the k-th
                # selected point; _update_selection skips points whose
                # waveform is None, which breaks this mapping -- TODO confirm.
                self.mark_data = [self.mark_data[k]]
                break
    for ax in self.scatter_ax:
        if ax.collections and ax.collections[0] is event.artist:
            # event.ind is an index array; keep mark_data a list of ints
            # so that later selections can append to or remove from it:
            self.mark_data = [int(i) for i in event.ind]
    self._update_selection()
    if event.mouseevent.dblclick:
        if len(self.mark_data) > 0:
            self.analyze_selection(self.mark_data[-1])
def _set_layout(self, width, height):
    """Update positions and visibility of all plots.

    Arranges histograms, scatter plots, color bar, and magnified plot in
    a grid of `self.maxcols` columns, and places the waveform plots (if
    any) in the remaining upper-right area. With `self.maxcols == 0`
    only the waveform plots are shown.

    Parameters
    ----------
    width: float
        Width of the figure. Callers pass `event.width` of a resize event
        or `fig.get_window_extent().width` (presumably pixels).
    height: float
        Height of the figure, from the same sources as `width`.
    """
    # borders and spacing (same units as width/height) as figure fractions:
    xoffs = self.xborder/width
    yoffs = self.yborder/height
    xs = self.spacing/width
    ys = self.spacing/height
    if self.maxcols > 0:
        # grid cell size (dx, dy) and plot size within a cell (xw, yw):
        dx = (1.0-xoffs)/self.maxcols
        dy = (1.0-yoffs)/self.maxcols
        xw = dx - xs
        yw = dy - ys
    # histograms:
    for c, ax in enumerate(self.hist_ax):
        if c < self.maxcols:
            ax.set_position([xoffs+c*dx, yoffs, xw, yw])
            ax.set_visible(True)
        else:
            ax.set_visible(False)
            ax.set_position([0.99, 0.01, 0.01, 0.01])
    # scatter plots:
    for ax, (c, r) in zip(self.scatter_ax[:-1], self.scatter_indices[:-1]):
        if r < self.maxcols:
            # row r counts from the top, the histograms form the bottom row:
            ax.set_position([xoffs+c*dx, yoffs+(self.maxcols-r)*dy, xw, yw])
            ax.set_visible(True)
        else:
            ax.set_visible(False)
            ax.set_position([0.99, 0.01, 0.01, 0.01])
    # color bar:
    if self.maxcols > 0:
        self.cbax.set_position([xoffs+dx, yoffs+(self.maxcols-1)*dy, 0.3*xoffs, yw])
        self.cbax.set_visible(True)
    else:
        self.cbax.set_visible(False)
        self.cbax.set_position([0.99, 0.01, 0.01, 0.01])
    # magnified plot:
    if self.maxcols > 0:
        self._set_magnified_pos(width, height)
        if self.magnified_backdrop is not None:
            # enlarge the backdrop by mborder around the plot's tight bounding box:
            bbox = self.scatter_ax[-1].get_tightbbox(self.fig.canvas.get_renderer())
            if bbox is not None:
                self.magnified_backdrop.set_bounds(bbox.x0 - self.mborder,
                                                   bbox.y0 - self.mborder,
                                                   bbox.width + 2*self.mborder,
                                                   bbox.height + 2*self.mborder)
    else:
        self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
        self.scatter_ax[-1].set_visible(False)
    # waveform plots:
    if len(self.wave_ax) > 0:
        if self.maxcols > 0:
            # start right of / above the middle of the scatter-plot grid:
            x0 = xoffs+((self.maxcols+1)//2)*dx
            y0 = ((self.maxcols+1)//2)*dy
            if self.maxcols%2 == 0:
                x0 += xoffs
                y0 += yoffs - ys
            else:
                y0 += ys
        else:
            x0 = xoffs
            y0 = 0.0
        # stack the waveform axes from the top (yp) downwards; dy is the
        # height available per axes after subtracting x-tick and title space:
        yp = 1.0
        dy = 1.0-y0
        dy -= np.sum(self.wave_has_xticks)*yoffs
        yp -= ys
        dy -= ys
        if self.wave_title:
            yp -= 2*ys
            dy -= 2*ys
        dy /= len(self.wave_ax)
        for ax, has_xticks in zip(self.wave_ax, self.wave_has_xticks):
            yp -= dy
            ax.set_position([x0, yp, 1.0-x0-xs, dy])
            if has_xticks:
                yp -= yoffs
            else:
                yp -= ys
def _update_layout(self):
    """Refresh the magnified plot and re-layout the whole figure.

    Clamps the magnified plot's column/row indices to the visible grid
    of `self.maxcols` columns and replots it as a scatter plot or as a
    histogram. It then resets the histogram y-limits, recomputes all
    plot positions for the current figure size, and redraws the canvas.
    """
    magnified = self.scatter_ax[-1]
    indices = self.scatter_indices[-1]
    if indices[1] < self.data.shape[1]:
        # scatter plot: row must be visible, column left of the row:
        indices[1] = min(indices[1], self.maxcols - 1)
        indices[0] = min(indices[0], indices[1] - 1)
        self._plot_scatter(magnified, True)
    else:
        # histogram: only the column index needs to be visible:
        indices[0] = min(indices[0], self.maxcols - 1)
        self._plot_hist(magnified, True)
    self._set_hist_ylim()
    extent = self.fig.get_window_extent()
    self._set_layout(extent.width, extent.height)
    self.fig.canvas.draw()
def _on_resize(self, event):
    """Re-layout all plots for the new figure size of a resize event.

    Parameters
    ----------
    event: matplotlib ResizeEvent
        Provides the new `width` and `height` of the figure canvas.
    """
    new_width, new_height = event.width, event.height
    self._set_layout(new_width, new_height)
def categorize(data):
    """Convert categorical string data into integer categories.

    Parameters
    ----------
    data: list or ndarray of str
        Data with textual categories.

    Returns
    -------
    categories: list of str
        A sorted unique list of the strings in `data`,
        that is the names of the categories.
    cdata: ndarray of int
        A copy of the input `data` where each string value is replaced
        by an integer number that is an index into the returned `categories`.
    """
    cats = sorted(set(data))
    # dict lookup instead of cats.index(): O(1) per element
    lookup = {cat: k for k, cat in enumerate(cats)}
    cdata = np.array([lookup[x] for x in data], dtype=int)
    return cats, cdata
def select_features(data, columns):
    """Assemble the list of column indices of the features to be explored.

    Parameters
    ----------
    data: TableData
        The table from which to select features.
    columns: list of str
        Feature names (column headers) to be selected from the data.
        An empty list selects all columns. Each additional mention of
        a column toggles it, so a column listed an even number of times
        is not selected. Invalid names are reported on the console and
        skipped.

    Returns
    -------
    data_cols: list of int
        List of indices into data columns for selecting features.
    """
    if len(columns) > 0:
        selected = []
        for name in columns:
            col = data.index(name)
            if col is None:
                print(f'"{name}" is not a valid data column')
                continue
            if col in selected:
                selected.remove(col)
            else:
                selected.append(col)
        return selected
    return list(np.arange(len(data)))
def select_coloring(data, data_cols, color_col):
    """Select column from data table for colorizing the data.

    Pass the output of this function on to MultivariateExplorer.set_colors().

    Parameters
    ----------
    data: TableData
        Table with all EOD properties from which columns are selected.
    data_cols: list of int
        List of columns selected to be explored.
    color_col: str or int or None
        Column to be selected for coloring the data.
        If 'row' then use the row index of the data in the table for coloring.
        If None (no color column requested), color by the first
        explored feature, i.e. index 0 of `data_cols`.

    Returns
    -------
    colors: int or list of values or None
        Either index of `data_cols` or additional data from the data table
        to be used for coloring. -2 for coloring by row index.
        None in case of an error.
    color_label: str or None
        Label for labeling the color bar.
    color_idx: int or None
        Index of color column in `data`.
    error: None or str
        In case an invalid column is selected, an error string.
    """
    if color_col is None:
        # the command line default: not an error, use the first feature
        return 0, None, None, None
    color_idx = data.index(color_col)
    if color_idx is None:
        if color_col != 'row':
            return None, None, None, f'"{color_col}" is not a valid column for color code'
        return -2, None, None, None
    if color_idx in data_cols:
        return data_cols.index(color_idx), None, color_idx, None
    # a column not explored itself: pass its values on as extra colors
    unit = data.unit(color_idx)
    if len(unit) > 0 and unit not in ['-', '1']:
        color_label = f'{data.label(color_idx)} [{unit}]'
    else:
        color_label = data.label(color_idx)
    return data[:, color_idx], color_label, color_idx, None
def list_available_features(data, data_cols=None, color_col=None):
    """Print available features on console.

    Each feature (column header) is printed on its own line. The
    prefix 'C' marks the column used for color coding, and '*' marks
    columns selected for exploration.

    Parameters
    ----------
    data: TableData
        The full data table.
    data_cols: list of int or None
        List of indices of columns (features) in the table
        that are passed on to the MultivariateExplorer.
        None (default) is treated as an empty list.
    color_col: int or None
        Index of the column (feature) in the table
        that is initially used for color coding the data.
    """
    if data_cols is None:
        data_cols = []
    print('available features:')
    for k, c in enumerate(data.keys()):
        s = [' '] * 3
        if k in data_cols:
            s[1] = '*'
        if color_col is not None and k == color_col:
            s[0] = 'C'
        print(''.join(s) + c)
    if len(data_cols) > 0:
        print('*: feature selected for exploration')
    if color_col is not None:
        print('C: feature selected for color coding the data')
class PrintHelp(argparse.Action):
    """Argparse action printing the help text plus mouse and key bindings.

    In addition to the usual argparse help, lists the mouse actions and
    key shortcuts from `MultivariateExplorer.mouse_actions` and
    `MultivariateExplorer.key_actions`, and then exits.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parser.print_help()
        print('')
        print('mouse:')
        for ma in MultivariateExplorer.mouse_actions:
            print('%-23s %s' % ma)
        # double clicks call MultivariateExplorer.analyze_selection():
        print('%-23s %s' % ('double left click', 'analyze selected data point'))
        print('')
        print('key shortcuts:')
        for ka in MultivariateExplorer.key_actions:
            print('%-23s %s' % ka)
        parser.exit()
def demo():
    """Run the multivariate explorer with a random test data set.

    Generates 100 data points with three correlated random features and
    one categorical string feature. For each data point it also
    generates a sine and a Gaussian waveform, and shows them in the
    explorer.
    """
    n = 100
    # three correlated random features:
    first = np.random.randn(n) + 2.0
    second = 1.0 + 0.1*first + 1.5*np.random.randn(n)
    third = 10*(-3.0*first + 2.0*second + 1.8*np.random.randn(n))
    # one categorical feature:
    names = ['aaa', 'bbb', 'ccc']
    category = [names[i] for i in np.random.randint(0, 3, n)]
    data = [first, second, third, category]
    # one sine and one Gaussian waveform per data point:
    time = np.arange(0.0, 10.0, 0.01)
    waveforms = []
    for amp, freq, phase in zip(first, second, third):
        sine = amp*np.sin(2.0*np.pi*freq*time + phase)
        gauss = amp*np.exp(-0.5*((time-freq)/(0.2*phase))**2.0)
        waveforms.append(np.column_stack((time, sine, gauss)))
        # nested alternative:
        # waveforms.append([np.column_stack((time, sine)), np.column_stack((time, gauss))])
    labels = [chr(ord('A') + k) for k in range(len(data))]
    expl = MultivariateExplorer(data, labels, 'Explorer')
    expl.set_wave_data(waveforms, 'Time', ['Sine', 'Gauss'])
    expl.set_colors()
    expl.show()
def main(*cargs):
    """Command line script for viewing and exploring a table of data.

    Without a file argument a demo with random data is shown.

    Parameters
    ----------
    cargs: list of str
        Command line arguments. If none are given, argparse reads
        them from `sys.argv`.
    """
    parser = argparse.ArgumentParser(
        add_help=False,
        description='View and explore multivariate data.',
        epilog=f'version {__version__} by Benda-Lab (2019-{__year__})')
    parser.add_argument('-h', '--help', nargs=0, action=PrintHelp,
                        help='show this help message and exit')
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-l', dest='list_features', action='store_true',
                        help='list all available data columns (features) and exit')
    parser.add_argument('-d', dest='data_cols', action='append',
                        default=[], metavar='COLUMN',
                        help='data columns (features) to be explored')
    parser.add_argument('-c', dest='color_col', default=None,
                        type=str, metavar='COLUMN',
                        help='data column to be used for color code or "row"')
    parser.add_argument('-m', dest='color_map', default='jet',
                        type=str, metavar='CMAP',
                        help='name of color map to be used')
    parser.add_argument('file', nargs='?', default='', type=str,
                        help='a file containing a table of data (csv file or similar)')
    args = parser.parse_args(cargs if len(cargs) > 0 else None)
    if not args.file:
        # no data file given: show random test data
        demo()
        return
    data = TableData(args.file)
    data_cols = select_features(data, args.data_cols)
    colors, color_label, color_col, error = select_coloring(data, data_cols,
                                                            args.color_col)
    if error:
        parser.error(error)
    if args.list_features:
        list_available_features(data, data_cols, color_col)
        parser.exit()
    expl = MultivariateExplorer(data[:, data_cols])
    expl.set_colors(colors, color_label, args.color_map)
    expl.show()
if __name__ == '__main__':
    # run the command line interface with the script's arguments:
    main(*sys.argv[1:])