Coverage for src/thunderlab/multivariateexplorer.py: 88%

981 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-18 22:10 +0000

1"""Simple GUI for viewing and exploring multivariate data. 

2 

3- `class MultiVariateExplorer`: simple matplotlib-based GUI for viewing and exploring multivariate data. 

4- `categorize()`: convert categorical string data into integer categories.

5- `select_features()`: assemble list of column indices. 

6- `select_coloring()`: select column from data table for colorizing the data. 

7- `list_available_features()`: print available features on console. 

8""" 

9 

10import sys 

11import numpy as np 

12from scipy.stats import pearsonr 

13from sklearn import decomposition 

14from sklearn import preprocessing 

15import matplotlib.pyplot as plt 

16import matplotlib.patches as patches 

17import matplotlib.widgets as widgets 

18import argparse 

19from .version import __version__, __year__ 

20from .tabledata import TableData 

21 

22 

23class MultivariateExplorer(object): 

24 """Simple matplotlib-based GUI for viewing and exploring multivariate data. 

25 

26 Shown are scatter plots of all pairs of variables or PCA axes.

27 Points in the scatter plots are colored according to the values of one of the variables. 

28 Data points can be selected and optionally corresponding waveforms are shown. 

29 

30 First you initialize the explorer with the data. Then you optionally 

31 specify how to colorize the data and provide waveform data 

32 associated with the data. Finally you show the figure: 

33 ``` 

34 expl = MultivariateExplorer(data) 

35 expl.set_colors(2) 

36 expl.set_wave_data(waveforms, 'Time [s]', 'Sine') 

37 expl.show() 

38 ``` 

39 

40 The `compute_pca()` function computes a principal component analysis (PCA)

41 on the input data, and `save_pca()` writes the principal components to a file. 

42 

43 Customize the appearance and information provided by subclassing 

44 MultivariateExplorer and reimplementing the functions 

45 - fix_scatter_plot() 

46 - fix_waveform_plot() 

47 - list_selection() 

48 - analyze_selection() 

49 See the documentation of these functions for details. 

50 """ 

51 

52 mouse_actions = [ 

53 ('left click', 'select sample'), 

54 ('left and drag', 'rectangular selection of samples and/or zoom'), 

55 ('shift + left click/drag', 'add samples to selection'), 

56 ('ctrl + left click/drag', 'remove samples from selection') 

57 ] 

58 """List of tuples with mouse actions and a description of their action.""" 

59 

60 key_actions = [ 

61 ('c, C', 'cycle color map trough data columns'), 

62 ('p,P', 'toggle between features, PCs, and scaled PCs'), 

63 ('<, pageup', 'decrease number of displayed featured/PCs'), 

64 ('>, pagedown', 'increase number of displayed features/PCs'), 

65 ('o, z', 'toggle zoom mode on or off'), 

66 ('backspace', 'zoom back'), 

67 ('n, N', 'decrease, increase number of bins of histograms'), 

68 ('H', 'toggle between scatter plot and 2D histogram'), 

69 ('left, right, up, down', 'show and move magnified scatter plot'), 

70 ('escape', 'close magnified scatter plot'), 

71 ('ctrl + a', 'select all'), 

72 ('+, -', 'increase, decrease pick radius'), 

73 ('0', 'reset pick radius'), 

74 ('l', 'list selection on console'), 

75 ('w', 'toggle maximized waveform plot'), 

76 ('h', 'toggle help window'), 

77 ] 

78 """List of tuples with key shortcuts and a description of their action.""" 

79 

80 def __init__(self, data, labels=None, title=None): 

81 """Initialize explorer with scatter-plot data. 

82 

83 Parameters 

84 ---------- 

85 data: TableData, 2D array, or list of 1D arrays 

86 The data to be explored. Each column is a variable. 

87 For the 2D array the columns are the second dimension, 

88 for a list of 1D arrays, the list goes over columns, 

89 i.e. each 1D array is one column. 

90 labels: list of str 

91 If data is not a TableData, then this provides labels 

92 for the data columns. 

93 title: str 

94 Title for the window. 

95 """ 

96 # data. categories and labels: 

97 self.raw_data = None # original data table as 2D numpy array (samples x features) 

98 self.raw_labels = None # for each feature a label, optional with unit 

99 self.categories = [] # for each feature None or list of categories 

100 if isinstance(data, TableData): 

101 for c, col in enumerate(data): 

102 if not isinstance(data[col][0], (int, float, 

103 np.integer, np.floating)): 

104 # categorial data: 

105 cats, data[:,c] = categorize(data[col]) 

106 self.categories.append(cats) 

107 else: 

108 self.categories.append(None) 

109 self.raw_data = data.array() 

110 if labels is None: 

111 self.raw_labels = [] 

112 for c in range(data.columns()): 

113 if len(data.unit(c)) > 0 and not data.unit(c) in ['-', '1']: 

114 self.raw_labels.append(f'{data.label(c)} [{data.unit(c)}]') 

115 else: 

116 self.raw_labels.append(data.label(c)) 

117 else: 

118 self.raw_labels = labels 

119 else: 

120 if isinstance(data, np.ndarray): 

121 self.raw_data = data 

122 self.categories = [None] * data.shape[1] 

123 else: 

124 for c, col in enumerate(data): 

125 if not isinstance(col[0], (int, float, 

126 np.integer, np.floating)): 

127 # categorial data: 

128 cats, data[c] = categorize(col) 

129 self.categories.append(cats) 

130 else: 

131 self.categories.append(None) 

132 self.raw_data = np.asarray(data).T 

133 self.raw_labels = labels 

134 # remove columns containing only invalid numbers: 

135 cols = np.all(~np.isfinite(self.raw_data), 0) 

136 if np.sum(cols) > 0: 

137 print('removed columns containing no numbers:', 

138 [l for l, c in zip(self.raw_labels, cols) if c]) 

139 self.raw_data = self.raw_data[:, ~cols] 

140 self.raw_labels = [l for l, c in zip(self.raw_labels, cols) if not c] 

141 # remove rows containing invalid numbers: 

142 self.valid_samples = ~np.any(~np.isfinite(self.raw_data), 1) 

143 self.raw_data = self.raw_data[self.valid_samples, :] 

144 if np.sum(~self.valid_samples) > 0: 

145 print(f'removed {np.sum(~self.valid_samples)} rows containing invalid numbers:') 

146 for k in range(len(self.valid_samples)): 

147 if not self.valid_samples[k]: 

148 print(k) 

149 self.valid_rows = [k for k in range(len(self.valid_samples)) 

150 if self.valid_samples[k]] 

151 # title for the window: 

152 self.title = title if title is not None else 'MultivariateExplorer' 

153 # data, pca-data, scaled-pca data (no pca data yet): 

154 self.all_data = [self.raw_data, None, None] 

155 self.all_labels = [self.raw_labels, None, None] 

156 self.all_maxcols = [self.raw_data.shape[1], None, None] 

157 self.all_titles = ['data', 'PCA', 'scaled PCA'] # added to window title  

158 # pca: 

159 self.pca_tables = [None, None] # raw and scaled pca coefficients 

160 self._pca_header(self.raw_data, self.raw_labels) # prepare header of the pca tables 

161 # start showing raw data: 

162 self.show_mode = 0 # show data, pca or scaled pca 

163 self.data = self.all_data[self.show_mode] # the data shown 

164 self.labels = self.all_labels[self.show_mode] # the feature labels shown 

165 self.maxcols = self.all_maxcols[self.show_mode] # maximum number of features currently shown 

166 if self.maxcols > 6: 

167 self.maxcols = 6 

168 # waveform data: 

169 self.wave_data = [] 

170 self.wave_nested = False 

171 self.wave_has_xticks = [] 

172 self.wave_xlabels = [] 

173 self.wave_ylabels = [] 

174 self.wave_title = False 

175 # colors: 

176 self.color_map = plt.get_cmap('jet') 

177 self.extra_colors = None # additional data column to be used for coloring 

178 self.extra_color_label = None # label for extra_colors 

179 self.extra_categories = None # category name for extra_colors if needed 

180 self.color_set_index = 0 # -1: rows and extra_colors, 0: data, 1: pca, 2: scaled pca 

181 self.color_index = 0 # column used for coloring with color_set_index 

182 self.color_values = None # data column currently used for coloring as specified by color_set_index and color_index 

183 self.color_label = None # label of data currently used for coloring 

184 self.data_colors = None # actual colors for color_values 

185 self.color_vmin = None 

186 self.color_vmax = None 

187 self.color_ticks = None 

188 self.cbax = None # axes of color bar 

189 # figure variables: 

190 self.plt_params = {} 

191 for k in ['toolbar', 'keymap.quit', 'keymap.back', 'keymap.forward', 

192 'keymap.zoom', 'keymap.pan', 'keymap.xscale', 'keymap.yscale']: 

193 self.plt_params[k] = plt.rcParams[k] 

194 if k != 'toolbar': 

195 plt.rcParams[k] = '' 

196 self.xborder = 100.0 # pixel for ylabels 

197 self.yborder = 50.0 # pixel for xlabels 

198 self.spacing = 10.0 # pixel between plots 

199 self.mborder = 20.0 # pixel around magnified plot 

200 self.pick_radius = 4.0 

201 # histogram plots: 

202 self.hist_ax = [] # list of histogram axes 

203 self.hist_indices = [] # feature index of the histogram axes 

204 self.hist_selector = [] # for each histogram axes a selector 

205 self.hist_nbins = 30 # number of bins for computing histograms 

206 # scatter plots: 

207 self.scatter_ax = [] # list of axes with scatter plots (1D) 

208 self.scatter_indices = [] # for each axes a tuple of the column and row index 

209 self.scatter_artists = [] # artists of selected scatter points 

210 self.scatter_selector = [] # selector for each axes 

211 self.scatter = True # scatter (True) or density (False) 

212 self.mark_data = [] # list of indices of selected data 

213 self.significance_level = 0.05 # r is bold if p is below 

214 self.select_zooms = False 

215 self.zoom_stack = [] 

216 # magnified scatter plot: 

217 self.magnified_on = False 

218 self.magnified_backdrop = None 

219 self.magnified_size = np.array([0.6, 0.6]) 

220 # waveform plots: 

221 self.wave_ax = [] 

222 # help window: 

223 self.help_ax = None 

224 

225 

226 def set_wave_data(self, data, xlabels='', ylabels=[], title=False): 

227 """Add waveform data to explorer. 

228 

229 Parameters 

230 ---------- 

231 data: list of (list of) 2D arrays 

232 Waveform data associated with each row of the data. 

233 Elements of the outer list correspond to the rows of the data. 

234 The inner 2D arrays contain a common x-axes (first column) 

235 and one or more corresponding y-values (second and optional higher columns). 

236 Each column for y-values is plotted in its own axes on top of each other, 

237 from top to bottom. 

238 The optional inner list of 2D arrays contains several 2D arrays as ascribed above 

239 each with its own common x-axes. 

240 xlabel: str or list of str 

241 The xlabels for the waveform plots. If only a string is given, then 

242 there will be a common xaxis for all the plots, and only the lowest 

243 one gets a labeled xaxis. If a list of strings is given, each waveform 

244 plot gets its own labeled x-axis. 

245 ylabels: list of str 

246 The ylabels for each of the waveform plots. 

247 title: bool or str 

248 If True or a string, povide space on top of the waveform plots for a title. 

249 If string, set this as the title for the waveform plots. 

250 """ 

251 self.wave_data = [] 

252 if data is not None and len(data) > 0: 

253 self.wave_data = data 

254 self.wave_has_xticks = [] 

255 self.wave_nested = isinstance(data[0], (list, tuple)) 

256 if self.wave_nested: 

257 for data in self.wave_data[0]: 

258 for k in range(data.shape[1]-2): 

259 self.wave_has_xticks.append(False) 

260 self.wave_has_xticks.append(True) 

261 else: 

262 for k in range(self.wave_data[0].shape[1]-2): 

263 self.wave_has_xticks.append(False) 

264 self.wave_has_xticks.append(True) 

265 if isinstance(xlabels, (list, tuple)): 

266 self.wave_xlabels = xlabels 

267 else: 

268 self.wave_xlabels = [xlabels] 

269 self.wave_ylabels = ylabels 

270 self.wave_title = title 

271 self.wave_ax = [] 

272 

273 

274 def set_colors(self, colors=0, color_label=None, color_map=None): 

275 """Set data column used to color scatter plots. 

276  

277 Parameters 

278 ---------- 

279 colors: int or 1D array 

280 Index to colum in data to be used for coloring scatter plots. 

281 -2 for coloring row index of data. 

282 Or data array used to color scalar plots. 

283 color_label: str 

284 If colors is an array, this is a label describing the data. 

285 It is used to label the color bar. 

286 color_map: str or None 

287 Name of a matplotlib color map. 

288 If None 'jet' is used. 

289 """ 

290 if isinstance(colors, (np.integer, int)): 

291 if colors < 0: 

292 self.color_set_index = -1 

293 self.color_index = 0 

294 else: 

295 self.color_set_index = 0 

296 self.color_index = colors 

297 else: 

298 if not isinstance(colors[0], (int, float, 

299 np.integer, np.floating)): 

300 # categorial data: 

301 self.extra_categories, self.extra_colors = categorize(colors) 

302 else: 

303 self.extra_colors = colors 

304 self.extra_colors = self.extra_colors[self.valid_samples] 

305 self.extra_color_label = color_label 

306 self.color_set_index = -1 

307 self.color_index = 1 

308 self.color_map = plt.get_cmap(color_map if color_map else 'jet') 

309 

310 

311 def show(self, ioff=True): 

312 """Show interactive scatter plots for exploration. 

313 """ 

314 if ioff: 

315 plt.ioff() 

316 else: 

317 plt.ion() 

318 plt.rcParams['toolbar'] = 'None' 

319 plt.rcParams['keymap.quit'] = 'ctrl+w, alt+q, ctrl+q, q' 

320 plt.rcParams['font.size'] = 12 

321 self.fig = plt.figure(facecolor='white', figsize=(10, 8)) 

322 self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode]) 

323 self.fig.canvas.mpl_connect('key_press_event', self._on_key) 

324 self.fig.canvas.mpl_connect('resize_event', self._on_resize) 

325 self.fig.canvas.mpl_connect('pick_event', self._on_pick) 

326 if self.color_map is None: 

327 self.color_map = plt.get_cmap('jet') 

328 self._set_color_column() 

329 self._init_hist_plots() 

330 self._init_scatter_plots() 

331 self.wave_ax = [] 

332 if self.wave_data is not None and len(self.wave_data) > 0: 

333 axx = None 

334 xi = 0 

335 for k, has_xticks in enumerate(self.wave_has_xticks): 

336 ax = self.fig.add_subplot(1, len(self.wave_has_xticks), 

337 1+k, sharex=axx) 

338 self.wave_ax.append(ax) 

339 if has_xticks: 

340 if xi >= len(self.wave_xlabels): 

341 self.wave_xlabels.append('') 

342 ax.set_xlabel(self.wave_xlabels[xi]) 

343 xi += 1 

344 axx = None 

345 else: 

346 #ax.xaxis.set_major_formatter(plt.NullFormatter()) 

347 if axx is None: 

348 axx = ax 

349 for ax, ylabel in zip(self.wave_ax, self.wave_ylabels): 

350 ax.set_ylabel(ylabel) 

351 if not isinstance(self.wave_title, bool) and self.wave_title: 

352 self.wave_ax[0].set_title(self.wave_title) 

353 self.fix_waveform_plot(self.wave_ax, self.mark_data) 

354 self._plot_magnified_scatter() 

355 self._plot_help() 

356 plt.show() 

357 

358 

359 def _pca_header(self, data, labels): 

360 """Set up header for the table of principal components. 

361 

362 Parameters 

363 ---------- 

364 data: ndarray of float 

365 The data (samples x features) without invalid (infinite or 

366 NaN) numbers. 

367 labels: list of str 

368 Labels of the features. 

369 """ 

370 lbs = [] 

371 for l, d in zip(labels, data): 

372 if '[' in l: 

373 lbs.append(l.split('[')[0].strip()) 

374 elif '/' in l: 

375 lbs.append(l.split('/')[0].strip()) 

376 else: 

377 lbs.append(l) 

378 header = TableData(header=lbs) 

379 header.set_formats('%.3f') 

380 header.insert(0, ['PC'] + ['-']*header.nsecs, '', '%d') 

381 header.insert(1, 'variance', '%', '%.3f') 

382 for k in range(len(self.pca_tables)): 

383 self.pca_tables[k] = TableData(header) 

384 

385 

386 def compute_pca(self, scale=False, write=False): 

387 """Compute PCA based on the data. 

388 

389 Parameters 

390 ---------- 

391 scale: boolean 

392 If True standardize data before computing PCA, i.e. remove mean 

393 of each variabel and divide by its standard deviation. 

394 write: boolean 

395 If True write PCA components to standard out. 

396 """ 

397 # pca: 

398 pca = decomposition.PCA() 

399 if scale: 

400 scaler = preprocessing.StandardScaler() 

401 scaler.fit(self.raw_data) 

402 pca.fit(scaler.transform(self.raw_data)) 

403 pca_label = 'sPC' 

404 else: 

405 pca.fit(self.raw_data) 

406 pca_label = 'PC' 

407 for k in range(len(pca.components_)): 

408 if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]): 

409 pca.components_[k] *= -1.0 

410 pca_data = pca.transform(self.raw_data) 

411 pca_labels = [f'{pca_label}{k+1} ' + (f'({100*v:.1f}%)' if v > 0.01 else (f'{100*v:.2f}%')) 

412 for k, v in enumerate(pca.explained_variance_ratio_)] 

413 if np.min(pca.explained_variance_ratio_) >= 0.01: 

414 pca_maxcols = pca_data.shape[1] 

415 else: 

416 pca_maxcols = np.argmax(pca.explained_variance_ratio_ < 0.01) 

417 if pca_maxcols < 2: 

418 pca_maxcols = 2 

419 if pca_maxcols > 6: 

420 pca_maxcols = 6 

421 # table with PCA feature weights: 

422 pca_table = self.pca_tables[1] if scale else self.pca_tables[0] 

423 pca_table.clear_data() 

424 pca_table.set_section(0, pca_label, pca_table.nsecs) 

425 for k, comp in enumerate(pca.components_): 

426 pca_table.add(k+1, 0) 

427 pca_table.add(100.0*pca.explained_variance_ratio_[k]) 

428 pca_table.add(comp) 

429 if write: 

430 pca_table.write(table_format='out', unit_style='none') 

431 # submit data: 

432 if scale: 

433 self.all_data[2] = pca_data 

434 self.all_labels[2] = pca_labels 

435 self.all_maxcols[2] = pca_maxcols 

436 else: 

437 self.all_data[1] = pca_data 

438 self.all_labels[1] = pca_labels 

439 self.all_maxcols[1] = pca_maxcols 

440 

441 

442 def save_pca(self, file_name, scale, **kwargs): 

443 """Write PCA data to file. 

444 

445 Parameters 

446 ---------- 

447 file_name: str 

448 Name of ouput file. 

449 scale: boolean 

450 If True write PCA components of standardized PCA. 

451 kwargs: dict 

452 Additional parameter for TableData.write() 

453 """ 

454 if scale: 

455 pca_file = file_name + '-pcacor' 

456 pca_table = self.pca_tables[1] 

457 else: 

458 pca_file = file_name + '-pcacov' 

459 pca_table = self.pca_tables[0] 

460 if 'unit_style' in kwargs: 

461 del kwargs['unit_style'] 

462 if 'table_format' in kwargs: 

463 pca_table.write(pca_file, unit_style='none', **kwargs) 

464 else: 

465 pca_file += '.dat' 

466 pca_table.write(pca_file, unit_style='none') 

467 

468 

469 def _set_color_column(self): 

470 """Initialize variables used for colorization of scatter points.""" 

471 if self.color_set_index == -1: 

472 if self.color_index == 0: 

473 self.color_values = np.arange(self.data.shape[0], dtype=float) 

474 self.color_label = 'sample' 

475 elif self.color_index == 1: 

476 self.color_values = self.extra_colors 

477 self.color_label = self.extra_color_label 

478 else: 

479 self.color_values = self.all_data[self.color_set_index][:,self.color_index] 

480 self.color_label = self.all_labels[self.color_set_index][self.color_index] 

481 self.color_vmin, self.color_vmax, self.color_ticks = \ 

482 self.fix_scatter_plot(self.cbax, self.color_values, 

483 self.color_label, 'c') 

484 if self.color_ticks is None: 

485 if self.color_set_index == 0 and \ 

486 self.categories[self.color_index] is not None: 

487 self.color_ticks = np.arange(len(self.categories[self.color_index])) 

488 elif self.color_set_index == -1 and \ 

489 self.color_index == 1 and \ 

490 self.extra_categories is not None: 

491 self.color_ticks = np.arange(len(self.extra_categories)) 

492 self.data_colors = self.color_map((self.color_values - self.color_vmin)/(self.color_vmax - self.color_vmin)) 

493 

494 

495 def _add_backdrop(self, ax): 

496 bbox = ax.get_tightbbox(self.fig.canvas.get_renderer()) 

497 if bbox is not None: 

498 self.magnified_backdrop = \ 

499 patches.Rectangle((bbox.x0 - self.mborder, 

500 bbox.y0 - self.mborder), 

501 bbox.width + 2*self.mborder, 

502 bbox.height + 2*self.mborder, 

503 transform=None, clip_on=False, 

504 facecolor='#f7f7f7', edgecolor='none', 

505 zorder=-5) 

506 ax.add_patch(self.magnified_backdrop) 

507 

508 

509 def _create_selector(self, ax): 

510 try: 

511 selector = \ 

512 widgets.RectangleSelector(ax, self._on_select, 

513 useblit=True, button=1, 

514 minspanx=0, minspany=0, 

515 spancoords='pixels', 

516 props=dict(facecolor='gray', 

517 edgecolor='gray', 

518 alpha=0.2, 

519 fill=True), 

520 state_modifier_keys=dict(move='', 

521 clear='', 

522 square='', 

523 center='ctrl')) 

524 except TypeError: 

525 # old matplotlib: 

526 selector = widgets.RectangleSelector(ax, self._on_select, 

527 useblit=True, button=1) 

528 return selector 

529 

530 

    def _plot_hist(self, ax, magnifiedax):
        """Plot and label a histogram.

        `ax` is either one of the histogram axes in `self.hist_ax`, or
        (for the magnified plot) one of the axes in `self.scatter_ax`, in
        which case the histogram of its x-feature is drawn.

        Parameters
        ----------
        ax: matplotlib axes
            The axes to draw the histogram into. It is cleared first.
        magnifiedax: bool
            True if `ax` is the magnified plot.
        """
        try:
            idx = self.hist_ax.index(ax)
            c = self.hist_indices[idx]
            in_hist = True
        except ValueError:
            # not a histogram axes: look it up among the scatter axes and
            # use its x-feature:
            idx = self.scatter_ax.index(ax)
            c = self.scatter_indices[idx][0]
            in_hist = False
        ax.clear()
        #ax.relim()
        #ax.autoscale(True)
        x = self.data[:,c]
        ax.hist(x, self.hist_nbins)
        #ax.autoscale(False)
        ax.set_xlabel(self.labels[c])
        ax.xaxis.set_major_locator(plt.AutoLocator())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        if self.show_mode == 0:
            # raw data: label categorical features with their category names
            if self.categories[c] is not None:
                ax.set_xticks(np.arange(len(self.categories[c])))
                ax.set_xticklabels(self.categories[c])
            self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
        if magnifiedax:
            ax.text(0.05, 0.9, f'n={len(self.data)}',
                    transform=ax.transAxes, zorder=100)
            ax.set_ylabel('count')
            # x-range from the histogram of the same feature (the magnified
            # plot is the last entry of scatter_indices):
            cax = self.hist_ax[self.scatter_indices[-1][0]]
            ax.set_xlim(cax.get_xlim())
        else:
            if c == 0:
                # only the histogram of the first feature gets the sample
                # count and y tick labels:
                ax.text(0.05, 0.9, f'n={len(self.data)}',
                        transform=ax.transAxes, zorder=100)
                ax.set_ylabel('count')
            else:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
        # a fresh selector, since ax.clear() was called above:
        selector = self._create_selector(ax)
        if in_hist:
            self.hist_selector[idx] = selector
        else:
            self.scatter_selector[idx] = selector
            self.scatter_artists[idx] = None
        ax.relim(True)
        if magnifiedax:
            self._add_backdrop(ax)

577 

578 

579 def _set_hist_ylim(self): 

580 ymax = np.max([ax.get_ylim() for ax in self.hist_ax[:self.maxcols]], 0)[1] 

581 for ax in self.hist_ax: 

582 ax.set_ylim(0, ymax) 

583 

584 

585 def _init_hist_plots(self): 

586 """Initial plots of the histograms.""" 

587 n = self.data.shape[1] 

588 self.hist_ax = [] 

589 for r in range(n): 

590 ax = self.fig.add_subplot(n, n, (n-1)*n+r+1) 

591 self.hist_ax.append(ax) 

592 self.hist_indices.append(r) 

593 self.hist_selector.append(None) 

594 self._plot_hist(ax, False) 

595 self._set_hist_ylim() 

596 

597 

    def _plot_scatter(self, ax, magnifiedax, cax=None):
        """Plot a scatter plot (or 2D histogram) of a pair of features.

        The feature pair is looked up in `self.scatter_indices` via the
        position of `ax` in `self.scatter_ax`. If `self.scatter` is True,
        all samples are drawn as colored points, annotated with Pearson's
        r and its p-value; otherwise a gray-scale 2D histogram is drawn.
        Selected samples (`self.mark_data`) are drawn on top as larger
        points.

        Parameters
        ----------
        ax: matplotlib axes
            The axes to plot into.
        magnifiedax: bool
            True if `ax` is the magnified scatter plot.
        cax: matplotlib axes or None
            If not None, the axes the color bar is drawn into
            (only in scatter mode).
        """
        idx = self.scatter_ax.index(ax)
        c, r = self.scatter_indices[idx]  # feature indices for x and y
        if self.scatter: # scatter plot
            ax.clear()
            a = ax.scatter(self.data[:,c], self.data[:,r], s=50,
                           edgecolors='white', linewidths=0.5,
                           picker=self.pick_radius, zorder=10)
            a.set_facecolor(self.data_colors)
            pr, pp = pearsonr(self.data[:,c], self.data[:,r])
            # bold if p is below the significance level divided by the
            # number of features:
            fw = 'bold' if pp < self.significance_level/self.data.shape[1] else 'normal'
            # text in the upper right corner for negative correlations,
            # upper left otherwise (presumably to avoid covering the points):
            if pr < 0:
                ax.text(0.95, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                        transform=ax.transAxes, ha='right', zorder=100)
            else:
                ax.text(0.05, 0.9, f'r={pr:.2f}, p={pp:.3f}', fontweight=fw,
                        transform=ax.transAxes, zorder=100)
            # color bar:
            if cax is not None:
                # temporary color-mapped scatter only needed for creating
                # the color bar; it is removed right away:
                a = ax.scatter(self.data[:, c], self.data[:, r],
                               c=self.color_values, cmap=self.color_map)
                self.fig.colorbar(a, cax=cax, ticks=self.color_ticks)
                a.remove()
                cax.set_ylabel(self.color_label)
                self.color_vmin, self.color_vmax, self.color_ticks = \
                    self.fix_scatter_plot(self.cbax, self.color_values,
                                          self.color_label, 'c')
                if self.color_ticks is None:
                    # categorical colors: label ticks with category names
                    if self.color_set_index == 0 and \
                       self.categories[self.color_index] is not None:
                        cax.set_yticklabels(self.categories[self.color_index])
                    elif self.color_set_index == -1 and \
                         self.color_index == 1 and \
                         self.extra_categories is not None:
                        cax.set_yticklabels(self.extra_categories)
        else: # histogram
            if self.show_mode == 0:
                self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
                self.fix_scatter_plot(ax, self.data[:,r], self.labels[r], 'y')
            # use the current axis ranges as the range of the 2D histogram:
            axrange = [ax.get_xlim(), ax.get_ylim()]
            ax.clear()
            ax.hist2d(self.data[:,c], self.data[:,r], self.hist_nbins,
                      range=axrange, cmap=plt.get_cmap('Greys'))
        # selected data:
        a = ax.scatter(self.data[self.mark_data, c],
                       self.data[self.mark_data, r], s=100,
                       edgecolors='black', linewidths=0.5,
                       picker=self.pick_radius, zorder=11)
        a.set_facecolor(self.data_colors[self.mark_data])
        self.scatter_artists[idx] = a
        ax.xaxis.set_major_locator(plt.AutoLocator())
        ax.yaxis.set_major_locator(plt.AutoLocator())
        ax.xaxis.set_major_formatter(plt.ScalarFormatter())
        ax.yaxis.set_major_formatter(plt.ScalarFormatter())
        if self.show_mode == 0:
            # raw data: category names as tick labels of categorical features
            if self.categories[c] is not None:
                ax.set_xticks(np.arange(len(self.categories[c])))
                ax.set_xticklabels(self.categories[c])
            if self.categories[r] is not None:
                ax.set_yticks(np.arange(len(self.categories[r])))
                ax.set_yticklabels(self.categories[r])
        if magnifiedax:
            ax.set_xlabel(self.labels[c])
            ax.set_ylabel(self.labels[r])
            # axis ranges from the small scatter plot of the same feature
            # pair (the magnified plot is the last entry):
            cax = self.scatter_ax[self.scatter_indices[:-1].index(self.scatter_indices[-1])]
            ax.set_xlim(cax.get_xlim())
            ax.set_ylim(cax.get_ylim())
        else:
            if c == 0:
                ax.set_ylabel(self.labels[r])
        if self.show_mode == 0:
            self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
            self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
        if not magnifiedax:
            ax.xaxis.set_major_formatter(plt.NullFormatter())
            if c > 0:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
            # align with the histograms of the x- and y-features:
            ax.set_xlim(*self.hist_ax[c].get_xlim())
            ax.set_ylim(*self.hist_ax[r].get_xlim())
        if magnifiedax:
            self._add_backdrop(ax)
        selector = self._create_selector(ax)
        self.scatter_selector[idx] = selector
        ax.relim(True)

683 

684 

685 def _init_scatter_plots(self): 

686 """Initial plots of scatter plots.""" 

687 self.cbax = self.fig.add_axes([0.5, 0.5, 0.1, 0.5]) 

688 cbax = self.cbax 

689 n = self.data.shape[1] 

690 for r in range(1, n): 

691 for c in range(r): 

692 ax = self.fig.add_subplot(n, n, (r-1)*n+c+1) 

693 self.scatter_ax.append(ax) 

694 self.scatter_indices.append([c, r]) 

695 self.scatter_artists.append(None) 

696 self.scatter_selector.append(None) 

697 self._plot_scatter(ax, False, cbax) 

698 cbax = None 

699 

700 

701 def _plot_magnified_scatter(self): 

702 """Initial plot of the magnified scatter plot.""" 

703 ax = self.fig.add_axes([0.5, 0.9, 0.05, 0.05]) 

704 ax.set_visible(False) 

705 self.magnified_on = False 

706 c = 0 

707 r = 1 

708 a = ax.scatter(self.data[:, c], self.data[:, r], 

709 s=50, edgecolors='none') 

710 a.set_facecolor(self.data_colors) 

711 a = ax.scatter(self.data[self.mark_data, c], 

712 self.data[self.mark_data, r], s=80) 

713 a.set_facecolor(self.data_colors[self.mark_data]) 

714 ax.set_xlabel(self.labels[c]) 

715 ax.set_ylabel(self.labels[r]) 

716 self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x') 

717 self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y') 

718 self.scatter_ax.append(ax) 

719 self.scatter_indices.append([c, r]) 

720 self.scatter_artists.append(a) 

721 self.scatter_selector.append(None) 

722 

723 

724 def _plot_help(self): 

725 ax = self.fig.add_subplot(1, 1, 1) 

726 ax.set_position([0.02, 0.02, 0.96, 0.96]) 

727 ax.xaxis.set_major_locator(plt.NullLocator()) 

728 ax.yaxis.set_major_locator(plt.NullLocator()) 

729 n = len(self.mouse_actions) + len(self.key_actions) + 4 

730 dy = 1/n 

731 y = 1 - 1.5*dy 

732 ax.text(0.05, y, 'Key shortcuts', transform=ax.transAxes, 

733 fontweight='bold') 

734 y -= dy 

735 for a, d in self.key_actions: 

736 ax.text(0.05, y, a, transform=ax.transAxes) 

737 ax.text(0.3, y, d, transform=ax.transAxes) 

738 y -= dy 

739 y -= dy 

740 ax.text(0.05, y, 'Mouse actions', transform=ax.transAxes, 

741 fontweight='bold') 

742 y -= dy 

743 for a, d in self.mouse_actions: 

744 ax.text(0.05, y, a, transform=ax.transAxes) 

745 ax.text(0.3, y, d, transform=ax.transAxes) 

746 y -= dy 

747 ax.set_visible(False) 

748 self.help_ax = ax 

749 

750 

751 def fix_scatter_plot(self, ax, data, label, axis): 

752 """Customize an axes of a scatter plot. 

753 

754 This function is called after a scatter plot has been plotted. 

755 Once for the x axes, once for the y axis and once for the color bar. 

756 Reimplement this function to set appropriate limits and ticks. 

757 

758 Return values are only used for the color bar (`axis='c'`). 

759 Otherwise they are ignored. 

760 

761 For example, ticks for phase variables can be nicely labeled 

762 using the unicode character for pi: 

763 ``` 

764 if 'phase' in label: 

765 if axis == 'y': 

766 ax.set_ylim(0.0, 2.0*np.pi) 

767 ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi)) 

768 ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0']) 

769 ``` 

770  

771 Parameters 

772 ---------- 

773 ax: matplotlib axes 

774 Axes of the scatter plot or color bar to be worked on. 

775 data: 1D array 

776 Data array of the axes. 

777 label: str 

778 Label coresponding to the data array. 

779 axis: str 

780 'x', 'y': set properties of x or y axes of ax. 

781 'c': set properies of color bar axes (note that ax can be None!) 

782 and return vmin, vmax, and ticks. 

783 

784 Returns 

785 ------- 

786 min: float 

787 minimum value of color bar axis 

788 max: float 

789 maximum value of color bar axis 

790 ticks: list of float 

791 position of ticks for color bar axis 

792 """ 

793 return np.nanmin(data), np.nanmax(data), None 

794 

795 

def fix_waveform_plot(self, axs, indices):
    """Customize waveform plots.

    Hook that is called once every time new waveforms have been
    plotted into the waveform axes. The default implementation does
    nothing. Reimplement it in a subclass to customize these plots,
    in particular to set axis limits and labels, the plot title, etc.
    You may even open a new figure (with a non-blocking `show()`).

    The following member variables might be useful:
    - `self.wave_data`: the full list of waveform data.
    - `self.wave_nested`: True if the elements of `self.wave_data` are
      lists of 2D arrays. Otherwise the elements are 2D arrays. The
      first column of a 2D array contains the x-values, further
      columns y-values.
    - `self.wave_has_xticks`: list of booleans for each axis. True if
      the axis has its own xticks.
    - `self.wave_xlabels`: list of xlabels (only for the axes where the
      corresponding entry in `self.wave_has_xticks` is True).
    - `self.wave_ylabels`: for each axis its ylabel.

    For example, you can set the linewidth of all plotted waveforms via:
    ```
    for ax in axs:
        for l in ax.lines:
            l.set_linewidth(3.0)
    ```
    or enable markers to be plotted:
    ```
    for ax, yl in zip(axs, self.wave_ylabels):
        if 'Power' in yl:
            for l in ax.lines:
                l.set_marker('.')
                l.set_markersize(15.0)
                l.set_markeredgewidth(0.5)
                l.set_markeredgecolor('k')
                l.set_markerfacecolor(l.get_color())
    ```
    It is also useful to reduce the maximum number of y-ticks:
    ```
    axs[0].yaxis.get_major_locator().set_params(nbins=7)
    ```
    or
    ```
    import matplotlib.ticker as ticker
    axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
    ```

    Parameters
    ----------
    axs: list of matplotlib axes
        Axes of the waveform plots to be worked on.
    indices: list of int
        Indices of the waveforms that have been selected and plotted.
    """
    return None

847 

848 

def list_selection(self, indices):
    """Print the table rows of the currently selected data points.

    Called when 'l' is pressed. The default implementation prints,
    for each selected data point, the corresponding entry of
    `self.valid_rows`. Reimplement it to print more meaningful
    information about the selection, or to do anything else.

    Parameters
    ----------
    indices: list of int
        Indices of the data points that have been selected.
    """
    print('')
    print('selected rows in data table:')
    # lazily look up each row so output and lookups interleave as before
    for table_row in (self.valid_rows[k] for k in indices):
        print(table_row)

866 

867 

def analyze_selection(self, index):
    """Hook for a closer look at a single selected data point.

    Called when a single data item was double clicked. The default
    implementation does nothing. Reimplement it to provide further
    details on this data point, for example in an additional figure
    window, which then should be shown non-blocking:
    `plt.show(block=False)`

    Parameters
    ----------
    index: int
        The index of the selected data point.
    """
    return None

883 

884 

885 def _set_magnified_pos(self, width, height): 

886 """Set position of magnified plot.""" 

887 if self.magnified_on: 

888 xoffs = self.xborder/width 

889 yoffs = self.yborder/height 

890 if self.scatter_indices[-1][1] < self.data.shape[1]: 

891 idx = self.scatter_indices[:-1].index(self.scatter_indices[-1]) 

892 pos = self.scatter_ax[idx].get_position().get_points() 

893 else: 

894 pos = self.hist_ax[self.scatter_indices[-1][0]].get_position().get_points() 

895 pos[0] = np.mean(pos, 0) - 0.5*self.magnified_size 

896 if pos[0][0] < xoffs: pos[0][0] = xoffs 

897 if pos[0][1] < yoffs: pos[0][1] = yoffs 

898 pos[1] = pos[0] + self.magnified_size 

899 if pos[1][0] > 1.0-self.spacing/width: pos[1][0] = 1.0-self.spacing/width 

900 if pos[1][1] > 1.0-self.spacing/height: pos[1][1] = 1.0-self.spacing/height 

901 pos[0] = pos[1] - self.magnified_size 

902 self.scatter_ax[-1].set_position([pos[0][0], pos[0][1], 

903 self.magnified_size[0], self.magnified_size[1]]) 

904 self.scatter_ax[-1].set_visible(True) 

905 else: 

906 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) 

907 self.scatter_ax[-1].set_visible(False) 

908 

909 

910 def _make_selection(self, ax, key, x0, x1, y0, y1): 

911 """Select points from a scatter or histogram plot.""" 

912 if not key in ['shift', 'control']: 

913 self.mark_data = [] 

914 if ax in self.scatter_ax: 

915 axi = self.scatter_ax.index(ax) 

916 # from scatter plots: 

917 c, r = self.scatter_indices[axi] 

918 if r < self.data.shape[1]: 

919 # from scatter: 

920 for ind, (x, y) in enumerate(zip(self.data[:, c], self.data[:, r])): 

921 if x >= x0 and x <= x1 and y >= y0 and y <= y1: 

922 if ind in self.mark_data: 

923 if key == 'control': 

924 self.mark_data.remove(ind) 

925 elif key != 'control': 

926 self.mark_data.append(ind) 

927 else: 

928 # from histogram: 

929 for ind, x in enumerate(self.data[:, c]): 

930 if x >= x0 and x <= x1: 

931 if ind in self.mark_data: 

932 if key == 'control': 

933 self.mark_data.remove(ind) 

934 elif key != 'control': 

935 self.mark_data.append(ind) 

936 elif ax in self.hist_ax: 

937 r = self.hist_indices[self.hist_ax.index(ax)] 

938 # from histogram: 

939 for ind, x in enumerate(self.data[:, r]): 

940 if x >= x0 and x <= x1: 

941 if ind in self.mark_data: 

942 if key == 'control': 

943 self.mark_data.remove(ind) 

944 elif key != 'control': 

945 self.mark_data.append(ind) 

946 

947 

948 def _update_selection(self): 

949 """Highlight selected points in the scatter plots and plot corresponding waveforms.""" 

950 # update scatter plots: 

951 for artist, (c, r) in zip(self.scatter_artists, self.scatter_indices): 

952 if artist is not None: 

953 if len(self.mark_data) == 0: 

954 artist.set_offsets(np.zeros((0, 2))) 

955 else: 

956 artist.set_offsets(list(zip(self.data[self.mark_data, c], 

957 self.data[self.mark_data, r]))) 

958 artist.set_facecolors(self.data_colors[self.mark_data]) 

959 # waveform plots: 

960 if len(self.wave_ax) > 0: 

961 axdi = 0 

962 axti = 1 

963 for xi, ax in enumerate(self.wave_ax): 

964 ax.clear() 

965 if len(self.mark_data) > 0: 

966 for idx in self.mark_data: 

967 if self.wave_nested: 

968 data = self.wave_data[idx][axdi] 

969 else: 

970 data = self.wave_data[idx] 

971 if data is not None: 

972 ax.plot(data[:, 0], data[:, axti], 

973 c=self.data_colors[idx], 

974 picker=self.pick_radius) 

975 axti += 1 

976 if self.wave_has_xticks[xi]: 

977 ax.set_xlabel(self.wave_xlabels[axdi]) 

978 axti = 1 

979 axdi += 1 

980 #else: 

981 # ax.xaxis.set_major_formatter(plt.NullFormatter()) 

982 for ax, ylabel in zip(self.wave_ax, self.wave_ylabels): 

983 ax.set_ylabel(ylabel) 

984 if not isinstance(self.wave_title, bool) and self.wave_title: 

985 self.wave_ax[0].set_title(self.wave_title) 

986 self.fix_waveform_plot(self.wave_ax, self.mark_data) 

987 self.fig.canvas.draw() 

988 

989 

990 def _set_limits(self, ax, x0, x1, y0, y1): 

991 if ax in self.hist_ax: 

992 ax.set_xlim(x0, x1) 

993 for hax in self.hist_ax: 

994 hax.set_ylim(y0, y1) 

995 cc = self.hist_indices[self.hist_ax.index(ax)] 

996 for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices): 

997 if c == cc: 

998 sax.set_xlim(x0, x1) 

999 if r == cc: 

1000 sax.set_ylim(x0, x1) 

1001 if ax in self.scatter_ax: 

1002 idx = self.scatter_ax.index(ax) 

1003 cc, rr = self.scatter_indices[idx] 

1004 self.hist_ax[cc].set_xlim(x0, x1) 

1005 self.hist_ax[rr].set_xlim(y0, y1) 

1006 for sax, (c, r) in zip(self.scatter_ax, self.scatter_indices): 

1007 if c == cc: 

1008 sax.set_xlim(x0, x1) 

1009 if c == rr: 

1010 sax.set_xlim(y0, y1) 

1011 if r == cc: 

1012 sax.set_ylim(x0, x1) 

1013 if r == rr: 

1014 sax.set_ylim(y0, y1) 

1015 

1016 

1017 def _on_key(self, event): 

1018 """Handle key events.""" 

1019 #print('pressed', event.key) 

1020 if event.key in ['left', 'right', 'up', 'down']: 

1021 if self.magnified_on: 

1022 mc, mr = self.scatter_indices[-1] 

1023 if event.key == 'left': 

1024 if mc > 0: 

1025 self.scatter_indices[-1][0] -= 1 

1026 elif mr > 1: 

1027 if mr >= self.data.shape[1]: 

1028 self.scatter_indices[-1][1] = self.maxcols - 1 

1029 else: 

1030 self.scatter_indices[-1][1] -= 1 

1031 self.scatter_indices[-1][0] = self.scatter_indices[-1][1] - 1 

1032 else: 

1033 self.scatter_indices[-1][0] = self.data.shape[1] - 1 

1034 self.scatter_indices[-1][1] = self.data.shape[1] 

1035 elif event.key == 'right': 

1036 if mc < mr - 1 and mc < self.maxcols - 1: 

1037 self.scatter_indices[-1][0] += 1 

1038 elif mr < self.maxcols: 

1039 self.scatter_indices[-1][0] = 0 

1040 self.scatter_indices[-1][1] += 1 

1041 if mr >= self.maxcols: 

1042 self.scatter_indices[-1][1] = self.data.shape[1] 

1043 else: 

1044 self.scatter_indices[-1][0] = 0 

1045 self.scatter_indices[-1][1] = 1 

1046 elif event.key == 'up': 

1047 if mr > mc + 1: 

1048 if mr >= self.data.shape[1]: 

1049 self.scatter_indices[-1][1] = self.maxcols - 1 

1050 else: 

1051 self.scatter_indices[-1][1] -= 1 

1052 elif mc > 0: 

1053 self.scatter_indices[-1][0] -= 1 

1054 self.scatter_indices[-1][1] = self.data.shape[1] 

1055 else: 

1056 self.scatter_indices[-1][0] = self.data.shape[1] - 1 

1057 self.scatter_indices[-1][1] = self.data.shape[1] 

1058 elif event.key == 'down': 

1059 if mr < self.maxcols: 

1060 self.scatter_indices[-1][1] += 1 

1061 if mr >= self.maxcols: 

1062 self.scatter_indices[-1][1] = self.data.shape[1] 

1063 elif mc < self.maxcols - 1: 

1064 self.scatter_indices[-1][0] += 1 

1065 self.scatter_indices[-1][1] = mc + 2 

1066 if self.scatter_indices[-1][1] >= self.maxcols: 

1067 self.scatter_indices[-1][1] = self.data.shape[1] 

1068 else: 

1069 self.scatter_indices[-1][0] = 0 

1070 self.scatter_indices[-1][1] = 1 

1071 for k in reversed(range(len(self.zoom_stack))): 

1072 if self.zoom_stack[k][0] == self.scatter_ax[-1]: 

1073 del self.zoom_stack[k] 

1074 self.scatter_ax[-1].clear() 

1075 self.scatter_ax[-1].set_visible(True) 

1076 self.magnified_on = True 

1077 self._set_magnified_pos(self.fig.get_window_extent().width, 

1078 self.fig.get_window_extent().height) 

1079 if self.scatter_indices[-1][1] < self.data.shape[1]: 

1080 self._plot_scatter(self.scatter_ax[-1], True) 

1081 else: 

1082 self._plot_hist(self.scatter_ax[-1], True) 

1083 self.fig.canvas.draw() 

1084 else: 

1085 if event.key == 'escape': 

1086 if len(self.scatter_ax) >= self.data.shape[1]: 

1087 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) 

1088 self.magnified_on = False 

1089 self.scatter_ax[-1].set_visible(False) 

1090 self.fig.canvas.draw() 

1091 elif event.key in 'oz': 

1092 self.select_zooms = not self.select_zooms 

1093 elif event.key == 'backspace': 

1094 if len(self.zoom_stack) > 0: 

1095 self._set_limits(*self.zoom_stack.pop()) 

1096 self.fig.canvas.draw() 

1097 elif event.key in '+=': 

1098 self.pick_radius *= 1.5 

1099 elif event.key in '-': 

1100 if self.pick_radius > 5.0: 

1101 self.pick_radius /= 1.5 

1102 elif event.key in '0': 

1103 self.pick_radius = 4.0 

1104 elif event.key in ['pageup', 'pagedown', '<', '>']: 

1105 if event.key in ['pageup', '<'] and self.maxcols > 2: 

1106 self.maxcols -= 1 

1107 elif event.key in ['pagedown', '>'] and self.maxcols < self.raw_data.shape[1]: 

1108 self.maxcols += 1 

1109 for ax in self.hist_ax: 

1110 self._plot_hist(ax, False) 

1111 self._update_layout() 

1112 elif event.key == 'w': 

1113 if len(self.wave_data) > 0: 

1114 if self.maxcols > 0: 

1115 self.all_maxcols[self.show_mode] = self.maxcols 

1116 self.maxcols = 0 

1117 else: 

1118 self.maxcols = self.all_maxcols[self.show_mode] 

1119 self._set_layout(self.fig.get_window_extent().width, 

1120 self.fig.get_window_extent().height) 

1121 self.fig.canvas.draw() 

1122 elif event.key == 'ctrl+a': 

1123 self.mark_data = list(range(len(self.data))) 

1124 self._update_selection() 

1125 elif event.key in 'cC': 

1126 if event.key in 'c': 

1127 self.color_index -= 1 

1128 if self.color_index < 0: 

1129 self.color_set_index -= 1 

1130 if self.color_set_index < -1: 

1131 self.color_set_index = len(self.all_data)-1 

1132 if self.color_set_index >= 0: 

1133 if self.all_data[self.color_set_index] is None: 

1134 self.compute_pca(self.color_set_index>1, True) 

1135 self.color_index = self.all_data[self.color_set_index].shape[1]-1 

1136 else: 

1137 self.color_index = 0 if self.extra_colors is None else 1 

1138 self._set_color_column() 

1139 else: 

1140 self.color_index += 1 

1141 if (self.color_set_index >= 0 and \ 

1142 self.color_index >= self.all_data[self.color_set_index].shape[1]) or \ 

1143 (self.color_set_index < 0 and \ 

1144 self.color_index >= (1 if self.extra_colors is None else 2)): 

1145 self.color_index = 0 

1146 self.color_set_index += 1 

1147 if self.color_set_index >= len(self.all_data): 

1148 self.color_set_index = -1 

1149 elif self.all_data[self.color_set_index] is None: 

1150 self.compute_pca(self.color_set_index>1, True) 

1151 self._set_color_column() 

1152 for ax in self.scatter_ax: 

1153 ax.collections[0].set_facecolors(self.data_colors) 

1154 for a in self.scatter_artists: 

1155 if a is not None: 

1156 a.set_facecolors(self.data_colors[self.mark_data]) 

1157 for ax in self.wave_ax: 

1158 for l, c in zip(ax.lines, self.data_colors[self.mark_data]): 

1159 l.set_color(c) 

1160 l.set_markerfacecolor(c) 

1161 self._plot_scatter(self.scatter_ax[0], False, self.cbax) 

1162 self.fix_scatter_plot(self.cbax, self.color_values, 

1163 self.color_label, 'c') 

1164 self.fig.canvas.draw() 

1165 elif event.key in 'nN': 

1166 if event.key in 'N': 

1167 self.hist_nbins = (self.hist_nbins*3)//2 

1168 elif self.hist_nbins >= 15: 

1169 self.hist_nbins = (self.hist_nbins*2)//3 

1170 else: 

1171 self.hist_nbins = 10 

1172 for ax in self.hist_ax: 

1173 self._plot_hist(ax, False) 

1174 self._set_hist_ylim() 

1175 if self.scatter_indices[-1][1] >= self.data.shape[1]: 

1176 self._plot_hist(self.scatter_ax[-1], True, True) 

1177 elif not self.scatter: 

1178 self._plot_scatter(self.scatter_ax[-1], True) 

1179 if not self.scatter: 

1180 for ax in self.scatter_ax[:-1]: 

1181 self._plot_scatter(ax, False) 

1182 self.fig.canvas.draw() 

1183 elif event.key in 'H': 

1184 self.scatter = not self.scatter 

1185 for ax in self.scatter_ax[:-1]: 

1186 self._plot_scatter(ax, False) 

1187 if self.scatter_indices[-1][1] < self.data.shape[1]: 

1188 self._plot_scatter(self.scatter_ax[-1], True) 

1189 self.fig.canvas.draw() 

1190 elif event.key in 'pP': 

1191 if len(self.scatter_ax) >= self.data.shape[1]: 

1192 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) 

1193 self.scatter_indices[-1] = [0, 1] 

1194 self.magnified_on = False 

1195 self.scatter_ax[-1].set_visible(False) 

1196 self.all_maxcols[self.show_mode] = self.maxcols 

1197 if event.key == 'P': 

1198 self.show_mode += 1 

1199 if self.show_mode >= len(self.all_data): 

1200 self.show_mode = 0 

1201 else: 

1202 self.show_mode -= 1 

1203 if self.show_mode < 0: 

1204 self.show_mode = len(self.all_data)-1 

1205 if self.show_mode == 1: 

1206 print('principal components') 

1207 elif self.show_mode == 2: 

1208 print('scaled principal components') 

1209 else: 

1210 print('data') 

1211 if self.all_data[self.show_mode] is None: 

1212 self.compute_pca(self.show_mode>1, True) 

1213 self.data = self.all_data[self.show_mode] 

1214 self.labels = self.all_labels[self.show_mode] 

1215 self.maxcols = self.all_maxcols[self.show_mode] 

1216 self.zoom_stack = [] 

1217 self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode]) 

1218 for ax in self.hist_ax: 

1219 self._plot_hist(ax, False) 

1220 self._set_hist_ylim() 

1221 for ax in self.scatter_ax: 

1222 self._plot_scatter(ax, False) 

1223 self._update_layout() 

1224 elif event.key in 'l': 

1225 if len(self.mark_data) > 0: 

1226 self.list_selection(self.mark_data) 

1227 elif event.key in 'h': 

1228 self.help_ax.set_visible(not self.help_ax.get_visible()) 

1229 self.fig.canvas.draw() 

1230 

1231 

1232 def _on_select(self, eclick, erelease): 

1233 """Handle selection events.""" 

1234 if eclick.dblclick: 

1235 if len(self.mark_data) > 0: 

1236 self.analyze_selection(self.mark_data[-1]) 

1237 return 

1238 x0 = min(eclick.xdata, erelease.xdata) 

1239 x1 = max(eclick.xdata, erelease.xdata) 

1240 y0 = min(eclick.ydata, erelease.ydata) 

1241 y1 = max(eclick.ydata, erelease.ydata) 

1242 ax = erelease.inaxes 

1243 if ax is None: 

1244 ax = eclick.inaxes 

1245 xmin, xmax = ax.get_xlim() 

1246 ymin, ymax = ax.get_ylim() 

1247 dx = 0.02*(xmax-xmin) 

1248 dy = 0.02*(ymax-ymin) 

1249 if x1 - x0 < dx and y1 - y0 < dy: 

1250 bbox = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted()) 

1251 width, height = bbox.width, bbox.height 

1252 width *= self.fig.dpi 

1253 height *= self.fig.dpi 

1254 dx = self.pick_radius*(xmax-xmin)/width 

1255 dy = self.pick_radius*(ymax-ymin)/height 

1256 x0 = erelease.xdata - dx 

1257 x1 = erelease.xdata + dx 

1258 y0 = erelease.ydata - dy 

1259 y1 = erelease.ydata + dy 

1260 elif self.select_zooms: 

1261 self.zoom_stack.append((ax, xmin, xmax, ymin, ymax)) 

1262 self._set_limits(ax, x0, x1, y0, y1) 

1263 self._make_selection(ax, erelease.key, x0, x1, y0, y1) 

1264 self._update_selection() 

1265 

1266 

1267 def _on_pick(self, event): 

1268 """Handle pick events.""" 

1269 for ax in self.wave_ax: 

1270 for k, l in enumerate(ax.lines): 

1271 if l is event.artist: 

1272 self.mark_data = [self.mark_data[k]] 

1273 for ax in self.scatter_ax: 

1274 if ax.collections[0] is event.artist: 

1275 self.mark_data = event.ind 

1276 self._update_selection() 

1277 if event.mouseevent.dblclick: 

1278 if len(self.mark_data) > 0: 

1279 self.analyze_selection(self.mark_data[-1]) 

1280 

1281 

1282 def _set_layout(self, width, height): 

1283 """Update positions and visibility of all plots.""" 

1284 xoffs = self.xborder/width 

1285 yoffs = self.yborder/height 

1286 xs = self.spacing/width 

1287 ys = self.spacing/height 

1288 if self.maxcols > 0: 

1289 dx = (1.0-xoffs)/self.maxcols 

1290 dy = (1.0-yoffs)/self.maxcols 

1291 xw = dx - xs 

1292 yw = dy - ys 

1293 # histograms: 

1294 for c, ax in enumerate(self.hist_ax): 

1295 if c < self.maxcols: 

1296 ax.set_position([xoffs+c*dx, yoffs, xw, yw]) 

1297 ax.set_visible(True) 

1298 else: 

1299 ax.set_visible(False) 

1300 ax.set_position([0.99, 0.01, 0.01, 0.01]) 

1301 # scatter plots: 

1302 for ax, (c, r) in zip(self.scatter_ax[:-1], self.scatter_indices[:-1]): 

1303 if r < self.maxcols: 

1304 ax.set_position([xoffs+c*dx, yoffs+(self.maxcols-r)*dy, xw, yw]) 

1305 ax.set_visible(True) 

1306 else: 

1307 ax.set_visible(False) 

1308 ax.set_position([0.99, 0.01, 0.01, 0.01]) 

1309 # color bar: 

1310 if self.maxcols > 0: 

1311 self.cbax.set_position([xoffs+dx, yoffs+(self.maxcols-1)*dy, 0.3*xoffs, yw]) 

1312 self.cbax.set_visible(True) 

1313 else: 

1314 self.cbax.set_visible(False) 

1315 self.cbax.set_position([0.99, 0.01, 0.01, 0.01]) 

1316 # magnified plot: 

1317 if self.maxcols > 0: 

1318 self._set_magnified_pos(width, height) 

1319 if self.magnified_backdrop is not None: 

1320 bbox = self.scatter_ax[-1].get_tightbbox(self.fig.canvas.get_renderer()) 

1321 if bbox is not None: 

1322 self.magnified_backdrop.set_bounds(bbox.x0 - self.mborder, 

1323 bbox.y0 - self.mborder, 

1324 bbox.width + 2*self.mborder, 

1325 bbox.height + 2*self.mborder) 

1326 else: 

1327 self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05]) 

1328 self.scatter_ax[-1].set_visible(False) 

1329 # waveform plots: 

1330 if len(self.wave_ax) > 0: 

1331 if self.maxcols > 0: 

1332 x0 = xoffs+((self.maxcols+1)//2)*dx 

1333 y0 = ((self.maxcols+1)//2)*dy 

1334 if self.maxcols%2 == 0: 

1335 x0 += xoffs 

1336 y0 += yoffs - ys 

1337 else: 

1338 y0 += ys 

1339 else: 

1340 x0 = xoffs 

1341 y0 = 0.0 

1342 yp = 1.0 

1343 dy = 1.0-y0 

1344 dy -= np.sum(self.wave_has_xticks)*yoffs 

1345 yp -= ys 

1346 dy -= ys 

1347 if self.wave_title: 

1348 yp -= 2*ys 

1349 dy -= 2*ys 

1350 dy /= len(self.wave_ax) 

1351 for ax, has_xticks in zip(self.wave_ax, self.wave_has_xticks): 

1352 yp -= dy 

1353 ax.set_position([x0, yp, 1.0-x0-xs, dy]) 

1354 if has_xticks: 

1355 yp -= yoffs 

1356 else: 

1357 yp -= ys 

1358 

1359 

1360 def _update_layout(self): 

1361 """Update content and position of magnified plot.""" 

1362 if self.scatter_indices[-1][1] < self.data.shape[1]: 

1363 if self.scatter_indices[-1][1] >= self.maxcols: 

1364 self.scatter_indices[-1][1] = self.maxcols-1 

1365 if self.scatter_indices[-1][0] >= self.scatter_indices[-1][1]: 

1366 self.scatter_indices[-1][0] = self.scatter_indices[-1][1]-1 

1367 self._plot_scatter(self.scatter_ax[-1], True) 

1368 else: 

1369 if self.scatter_indices[-1][0] >= self.maxcols: 

1370 self.scatter_indices[-1][0] = self.maxcols-1 

1371 self._plot_hist(self.scatter_ax[-1], True) 

1372 self._set_hist_ylim() 

1373 self._set_layout(self.fig.get_window_extent().width, 

1374 self.fig.get_window_extent().height) 

1375 self.fig.canvas.draw() 

1376 

1377 

1378 def _on_resize(self, event): 

1379 """Adapt layout of plots to new figure size.""" 

1380 self._set_layout(event.width, event.height) 

1381 

1382 

def categorize(data):
    """Convert categorial string data into integer categories.

    Parameters
    ----------
    data: list or ndarray of str
        Data with textual categories.

    Returns
    -------
    categories: list of str
        A sorted unique list of the strings in `data`,
        that is the names of the categories.
    cdata: ndarray of int
        A copy of the input `data` where each string value is replaced
        by an integer number that is an index into the returned `categories`.
    """
    cats = sorted(set(data))
    # dictionary lookup instead of cats.index() avoids O(n*k) scans:
    cat_index = {cat: k for k, cat in enumerate(cats)}
    cdata = np.array([cat_index[x] for x in data], dtype=int)
    return cats, cdata

1403 

1404 

def select_features(data, columns):
    """Assemble list of column indices.

    Parameters
    ----------
    data: TableData
        The table from which to select features.
    columns: list of str
        Feature names (column headers) to be selected from the data.
        If empty, all columns are selected. A column listed an even
        number of times is not selected (each listing toggles it).
        Invalid column names are reported on the console and ignored.

    Returns
    -------
    data_cols: list of int
        List of indices into data columns for selecting features.
    """
    if len(columns) == 0:
        # NOTE(review): assumes len(data) is the number of columns of
        # the table (the indices are used as column indices by callers).
        # Plain ints, consistent with the indices returned below.
        return list(range(len(data)))
    data_cols = []
    for c in columns:
        idx = data.index(c)
        if idx is None:
            print(f'"{c}" is not a valid data column')
        elif idx in data_cols:
            # listing a column again deselects it:
            data_cols.remove(idx)
        else:
            data_cols.append(idx)
    return data_cols

1434 

1435 

def select_coloring(data, data_cols, color_col):
    """Select column from data table for colorizing the data.

    Pass the output of this function on to MultivariateExplorer.set_colors().

    Parameters
    ----------
    data: TableData
        Table with all data columns from which columns are selected.
    data_cols: list of int
        List of columns selected to be explored.
    color_col: str or int or None
        Column to be selected for coloring the data.
        If 'row' then use the row index of the data in the table for coloring.
        If None, no column is selected and all return values are None.

    Returns
    -------
    colors: int or list of values or None
        Either index of `data_cols` or additional data from the data table
        to be used for coloring. -2 for coloring by row index.
    color_label: str or None
        Label for labeling the color bar.
    color_idx: int or None
        Index of color column in `data`.
    error: None or str
        In case an invalid column is selected, an error string.
    """
    if color_col is None:
        # no color column requested (e.g. no `-c` on the command line);
        # this is not an invalid column:
        return None, None, None, None
    color_idx = data.index(color_col)
    if color_idx is None:
        if color_col != 'row':
            return None, None, None, f'"{color_col}" is not a valid column for color code'
        # color by row index of the data:
        return -2, None, None, None
    if color_idx in data_cols:
        return data_cols.index(color_idx), None, color_idx, None
    unit = data.unit(color_idx)
    if len(unit) > 0 and unit not in ['-', '1']:
        color_label = f'{data.label(color_idx)} [{unit}]'
    else:
        color_label = data.label(color_idx)
    return data[:, color_idx], color_label, color_idx, None

1479 

1480 

def list_available_features(data, data_cols=None, color_col=None):
    """Print available features on console.

    Each column name is prefixed by 'C' if it is used for color coding
    and by '*' if it is selected for exploration.

    Parameters
    ----------
    data: TableData
        The full data table.
    data_cols: list of int or None
        List of indices of columns (features) in the table
        that are passed on to the MultivariateExplorer.
        None means no column is marked.
    color_col: int or None
        Index of columns (feature) in the table
        that is initially used for color coding the data.
    """
    # avoid a shared mutable default argument:
    if data_cols is None:
        data_cols = []
    selected = set(data_cols)
    print('available features:')
    for k, c in enumerate(data.keys()):
        s = [' '] * 3
        if k in selected:
            s[1] = '*'
        if color_col is not None and k == color_col:
            s[0] = 'C'
        print(''.join(s) + c)
    if len(data_cols) > 0:
        print('*: feature selected for exploration')
    if color_col is not None:
        print('C: feature selected for color coding the data')

1507 

1508 

class PrintHelp(argparse.Action):
    """Argparse action printing usage, mouse and key bindings, then exiting."""

    def __call__(self, parser, namespace, values, option_string=None):
        parser.print_help()
        print('')
        print('mouse:')
        for ma in MultivariateExplorer.mouse_actions:
            print('%-23s %s' % ma)
        # a double click passes the selected data point on to
        # MultivariateExplorer.analyze_selection():
        print('%-23s %s' % ('double left click', 'analyze selected data point'))
        print('')
        print('key shortcuts:')
        for ka in MultivariateExplorer.key_actions:
            print('%-23s %s' % ka)
        parser.exit()

1522 

1523 

def demo():
    """Run the multivariate explorer with a random test data set.

    Generates three correlated numerical features and one categorical
    string feature for 100 data points, plus for each data point a
    sine and a Gaussian waveform whose parameters depend on the
    numerical features, and shows them in a MultivariateExplorer.
    """
    n = 100
    # correlated numerical features:
    first = np.random.randn(n) + 2.0
    second = 1.0 + 0.1*first + 1.5*np.random.randn(n)
    third = 10*(-3.0*first + 2.0*second + 1.8*np.random.randn(n))
    # categorical feature:
    category_names = ['aaa', 'bbb', 'ccc']
    categories = [category_names[i] for i in np.random.randint(0, 3, n)]
    data = [first, second, third, categories]
    # one sine and one Gaussian waveform per data point:
    time = np.arange(0.0, 10.0, 0.01)
    waveforms = []
    for amplitude, frequency, parameter in zip(first, second, third):
        sine = amplitude*np.sin(2.0*np.pi*frequency*time + parameter)
        gauss = amplitude*np.exp(-0.5*((time-frequency)/(0.2*parameter))**2.0)
        waveforms.append(np.column_stack((time, sine, gauss)))
        # nested alternative:
        # [np.column_stack((time, sine)), np.column_stack((time, gauss))]
    labels = [chr(ord('A') + k) for k in range(len(data))]
    expl = MultivariateExplorer(data, labels, 'Explorer')
    expl.set_wave_data(waveforms, 'Time', ['Sine', 'Gauss'])
    expl.set_colors()
    expl.show()

1552 

1553 

def main(*cargs):
    """Command line script for exploring multivariate data.

    Without a file argument a demo with random data is shown.

    Parameters
    ----------
    cargs: list of str
        Command line arguments as provided by `sys.argv[1:]`.
        If empty, the arguments are taken from `sys.argv`.
    """
    cli = argparse.ArgumentParser(add_help=False,
        description='View and explore multivariate data.',
        epilog = f'version {__version__} by Benda-Lab (2019-{__year__})')
    cli.add_argument('-h', '--help', nargs=0, action=PrintHelp,
                     help='show this help message and exit')
    cli.add_argument('--version', action='version', version=__version__)
    cli.add_argument('-l', dest='list_features', action='store_true',
                     help='list all available data columns (features) and exit')
    cli.add_argument('-d', dest='data_cols', action='append',
                     default=[], metavar='COLUMN',
                     help='data columns (features) to be explored')
    cli.add_argument('-c', dest='color_col', default=None,
                     type=str, metavar='COLUMN',
                     help='data column to be used for color code or "row"')
    cli.add_argument('-m', dest='color_map', default='jet',
                     type=str, metavar='CMAP',
                     help='name of color map to be used')
    cli.add_argument('file', nargs='?', default='', type=str,
                     help='a file containing a table of data (csv file or similar)')
    # an empty tuple means: parse sys.argv
    args = cli.parse_args(cargs if cargs else None)
    if not args.file:
        demo()
        return
    table = TableData(args.file)
    data_cols = select_features(table, args.data_cols)
    colors, color_label, color_col, error = \
        select_coloring(table, data_cols, args.color_col)
    if error:
        cli.error(error)
    if args.list_features:
        list_available_features(table, data_cols, color_col)
        cli.exit()
    expl = MultivariateExplorer(table[:, data_cols])
    expl.set_colors(colors, color_label, args.color_map)
    expl.show()

1597 

1598 

if __name__ == '__main__':
    # run the command line interface with the script's arguments
    cli_args = sys.argv[1:]
    main(*cli_args)