Coverage for src/thunderlab/multivariateexplorer.py: 88%

981 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-06-26 11:35 +0000

1"""Simple GUI for viewing and exploring multivariate data. 

2 

3- `class MultivariateExplorer`: simple matplotlib-based GUI for viewing and exploring multivariate data.

4- `categorize()`: convert categorial string data into integer categories. 

5- `select_features()`: assemble list of column indices. 

6- `select_coloring()`: select column from data table for colorizing the data. 

7- `list_available_features()`: print available features on console. 

8""" 

9 

10import sys 

11import numpy as np 

12from scipy.stats import pearsonr 

13from sklearn import decomposition 

14from sklearn import preprocessing 

15import matplotlib.pyplot as plt 

16import matplotlib.patches as patches 

17import matplotlib.widgets as widgets 

18import argparse 

19from .version import __version__, __year__ 

20from .tabledata import TableData 

21 

22 

23class MultivariateExplorer(object): 

24 """Simple matplotlib-based GUI for viewing and exploring multivariate data. 

25 

26 Shown are scatter plots of all pairs of variables or PCA axis. 

27 Points in the scatter plots are colored according to the values of one of the variables. 

28 Data points can be selected and optionally corresponding waveforms are shown. 

29 

30 First you initialize the explorer with the data. Then you optionally 

31 specify how to colorize the data and provide waveform data 

32 associated with the data. Finally you show the figure: 

33 ``` 

34 expl = MultivariateExplorer(data) 

35 expl.set_colors(2) 

36 expl.set_wave_data(waveforms, 'Time [s]', 'Sine') 

37 expl.show() 

38 ``` 

39 

40 The `compute_pca()` function computes a principal component analysis (PCA)

41 on the input data, and `save_pca()` writes the principal components to a file. 

42 

43 Customize the appearance and information provided by subclassing 

44 MultivariateExplorer and reimplementing the functions 

45 - fix_scatter_plot() 

46 - fix_waveform_plot() 

47 - list_selection() 

48 - analyze_selection() 

49 See the documentation of these functions for details. 

50 """ 

51 

52 mouse_actions = [ 

53 ('left click', 'select sample'), 

54 ('left and drag', 'rectangular selection of samples and/or zoom'), 

55 ('shift + left click/drag', 'add samples to selection'), 

56 ('ctrl + left click/drag', 'remove samples from selection') 

57 ] 

58 """List of tuples with mouse actions and a description of their action.""" 

59 

60 key_actions = [ 

61 ('c, C', 'cycle color map trough data columns'), 

62 ('p,P', 'toggle between features, PCs, and scaled PCs'), 

63 ('<, pageup', 'decrease number of displayed featured/PCs'), 

64 ('>, pagedown', 'increase number of displayed features/PCs'), 

65 ('o, z', 'toggle zoom mode on or off'), 

66 ('backspace', 'zoom back'), 

67 ('n, N', 'decrease, increase number of bins of histograms'), 

68 ('H', 'toggle between scatter plot and 2D histogram'), 

69 ('left, right, up, down', 'show and move magnified scatter plot'), 

70 ('escape', 'close magnified scatter plot'), 

71 ('ctrl + a', 'select all'), 

72 ('+, -', 'increase, decrease pick radius'), 

73 ('0', 'reset pick radius'), 

74 ('l', 'list selection on console'), 

75 ('w', 'toggle maximized waveform plot'), 

76 ('h', 'toggle help window'), 

77 ] 

78 """List of tuples with key shortcuts and a description of their action.""" 

79 

def __init__(self, data, labels=None, title=None):
    """Initialize explorer with scatter-plot data.

    Columns that contain no finite number at all are dropped, together
    with their labels and categories. Afterwards all rows that still
    contain a non-finite number are dropped. Both removals are
    reported on the console.

    Side effects: categorical (non-numeric) columns are converted in
    place in `data` via `categorize()`, and several matplotlib key-map
    `rcParams` are cleared (their previous values are remembered in
    `self.plt_params`).

    Parameters
    ----------
    data: TableData, 2D array, or list of 1D arrays
        The data to be explored. Each column is a variable.
        For the 2D array the columns are the second dimension,
        for a list of 1D arrays, the list goes over columns,
        i.e. each 1D array is one column.
    labels: list of str
        If data is not a TableData, then this provides labels
        for the data columns. If None, labels are taken from the
        TableData, or, for other inputs, generated as
        'feature 0', 'feature 1', ...
    title: str
        Title for the window.
    """
    numeric_types = (int, float, np.integer, np.floating)
    # data, categories and labels:
    self.raw_data = None    # original data table as 2D numpy array (samples x features)
    self.raw_labels = None  # for each feature a label, optional with unit
    self.categories = []    # for each feature None or list of categories
    if isinstance(data, TableData):
        for c, col in enumerate(data):
            if not isinstance(col[0], numeric_types):
                # categorical data:
                cats, data[:,c] = categorize(col)
                self.categories.append(cats)
            else:
                self.categories.append(None)
        self.raw_data = data.array()
        if labels is None:
            self.raw_labels = []
            for c in range(len(data)):
                unit = data.unit(c)
                if len(unit) > 0 and unit not in ['-', '1']:
                    self.raw_labels.append(f'{data.label(c)} [{unit}]')
                else:
                    self.raw_labels.append(data.label(c))
        else:
            self.raw_labels = labels
    else:
        if isinstance(data, np.ndarray):
            self.raw_data = data
            self.categories = [None] * data.shape[1]
        else:
            for c, col in enumerate(data):
                if not isinstance(col[0], numeric_types):
                    # categorical data:
                    cats, data[c] = categorize(col)
                    self.categories.append(cats)
                else:
                    self.categories.append(None)
            self.raw_data = np.asarray(data).T
        self.raw_labels = labels
    if self.raw_labels is None:
        # without labels the PCA header and the axis labels cannot be
        # built, so fall back to generic names:
        self.raw_labels = [f'feature {k}'
                           for k in range(self.raw_data.shape[1])]
    # remove columns containing only invalid numbers:
    cols = np.all(~np.isfinite(self.raw_data), 0)
    if np.sum(cols) > 0:
        print('removed columns containing no numbers:',
              [l for l, c in zip(self.raw_labels, cols) if c])
        self.raw_data = self.raw_data[:, ~cols]
        self.raw_labels = [l for l, c in zip(self.raw_labels, cols) if not c]
        # keep the categories aligned with the remaining columns:
        self.categories = [cat for cat, c in zip(self.categories, cols)
                           if not c]
    # remove rows containing invalid numbers:
    self.valid_samples = ~np.any(~np.isfinite(self.raw_data), 1)
    self.raw_data = self.raw_data[self.valid_samples, :]
    if np.sum(~self.valid_samples) > 0:
        print(f'removed {np.sum(~self.valid_samples)} rows containing invalid numbers:')
        for k in range(len(self.valid_samples)):
            if not self.valid_samples[k]:
                print(k)
    # indices of the kept rows in the original data table:
    self.valid_rows = [k for k in range(len(self.valid_samples))
                       if self.valid_samples[k]]
    # title for the window:
    self.title = title if title is not None else 'MultivariateExplorer'
    # data, pca-data, scaled-pca data (no pca data yet):
    self.all_data = [self.raw_data, None, None]
    self.all_labels = [self.raw_labels, None, None]
    self.all_maxcols = [self.raw_data.shape[1], None, None]
    self.all_titles = ['data', 'PCA', 'scaled PCA']  # added to window title
    # pca:
    self.pca_tables = [None, None]  # raw and scaled pca coefficients
    self._pca_header(self.raw_data, self.raw_labels)  # prepare header of the pca tables
    # start showing raw data:
    self.show_mode = 0  # show data, pca or scaled pca
    self.data = self.all_data[self.show_mode]        # the data shown
    self.labels = self.all_labels[self.show_mode]    # the feature labels shown
    self.maxcols = self.all_maxcols[self.show_mode]  # maximum number of features currently shown
    if self.maxcols > 6:
        self.maxcols = 6
    # waveform data:
    self.wave_data = []
    self.wave_nested = False
    self.wave_has_xticks = []
    self.wave_xlabels = []
    self.wave_ylabels = []
    self.wave_title = False
    # colors:
    self.color_map = plt.get_cmap('jet')
    self.extra_colors = None       # additional data column to be used for coloring
    self.extra_color_label = None  # label for extra_colors
    self.extra_categories = None   # category name for extra_colors if needed
    self.color_set_index = 0  # -1: rows and extra_colors, 0: data, 1: pca, 2: scaled pca
    self.color_index = 0      # column used for coloring with color_set_index
    self.color_values = None  # data column currently used for coloring as specified by color_set_index and color_index
    self.color_label = None   # label of data currently used for coloring
    self.data_colors = None   # actual colors for color_values
    self.color_vmin = None
    self.color_vmax = None
    self.color_ticks = None
    self.cbax = None  # axes of color bar
    # figure variables:
    # remember the current settings and disable matplotlib's own
    # key bindings (the toolbar setting is only remembered):
    self.plt_params = {}
    for k in ['toolbar', 'keymap.quit', 'keymap.back', 'keymap.forward',
              'keymap.zoom', 'keymap.pan', 'keymap.xscale', 'keymap.yscale']:
        self.plt_params[k] = plt.rcParams[k]
        if k != 'toolbar':
            plt.rcParams[k] = ''
    self.xborder = 100.0  # pixel for ylabels
    self.yborder = 50.0   # pixel for xlabels
    self.spacing = 10.0   # pixel between plots
    self.mborder = 20.0   # pixel around magnified plot
    self.pick_radius = 4.0
    # histogram plots:
    self.hist_ax = []        # list of histogram axes
    self.hist_indices = []   # feature index of the histogram axes
    self.hist_selector = []  # for each histogram axes a selector
    self.hist_nbins = 30     # number of bins for computing histograms
    # scatter plots:
    self.scatter_ax = []        # list of axes with scatter plots (1D)
    self.scatter_indices = []   # for each axes a tuple of the column and row index
    self.scatter_artists = []   # artists of selected scatter points
    self.scatter_selector = []  # selector for each axes
    self.scatter = True         # scatter (True) or density (False)
    self.mark_data = []         # list of indices of selected data
    self.significance_level = 0.05  # r is bold if p is below
    self.select_zooms = False
    self.zoom_stack = []
    # magnified scatter plot:
    self.magnified_on = False
    self.magnified_backdrop = None
    self.magnified_size = np.array([0.6, 0.6])
    # waveform plots:
    self.wave_ax = []
    # help window:
    self.help_ax = None

225 

226 

def set_wave_data(self, data, xlabels='', ylabels=[], title=False):
    """Add waveform data to explorer.

    Parameters
    ----------
    data: list of (list of) 2D arrays
        Waveform data associated with each row of the data.
        Elements of the outer list correspond to the rows of the data.
        The inner 2D arrays contain a common x-axes (first column)
        and one or more corresponding y-values (second and optional higher columns).
        Each column for y-values is plotted in its own axes on top of each other,
        from top to bottom.
        The optional inner list of 2D arrays contains several 2D arrays as described above
        each with its own common x-axes.
    xlabels: str or list of str
        The xlabels for the waveform plots. If only a string is given, then
        there will be a common xaxis for all the plots, and only the lowest
        one gets a labeled xaxis. If a list of strings is given, each waveform
        plot gets its own labeled x-axis.
    ylabels: list of str
        The ylabels for each of the waveform plots.
    title: bool or str
        If True or a string, provide space on top of the waveform plots for a title.
        If string, set this as the title for the waveform plots.
    """
    self.wave_data = []
    if data is not None and len(data) > 0:
        self.wave_data = data
        first = data[0]
        self.wave_nested = isinstance(first, (list, tuple))
        # only the first sample is inspected to derive the axes layout:
        groups = first if self.wave_nested else [first]
        self.wave_has_xticks = []
        for group in groups:
            # one axes per y-column; only the last axes of each group
            # (sharing one x-column) gets x-ticks:
            self.wave_has_xticks.extend([False]*max(group.shape[1] - 2, 0))
            self.wave_has_xticks.append(True)
    # store private list copies: show() appends to wave_xlabels, which
    # must neither fail on tuples nor modify the caller's list, and the
    # shared default of ylabels must not be aliased:
    if isinstance(xlabels, (list, tuple)):
        self.wave_xlabels = list(xlabels)
    else:
        self.wave_xlabels = [xlabels]
    self.wave_ylabels = list(ylabels)
    self.wave_title = title
    self.wave_ax = []

273 

274 

def set_colors(self, colors=0, color_label=None, color_map=None):
    """Set data column used to color scatter plots.

    Parameters
    ----------
    colors: int or 1D array
        Index to column in data to be used for coloring scatter plots.
        Any negative index (e.g. -2) colors by row index of the data.
        Or data array used to color scatter plots. It needs one
        entry per row of the original data; entries of rows that
        were removed because of invalid numbers are discarded.
        Non-numeric entries are converted by `categorize()`.
    color_label: str
        If colors is an array, this is a label describing the data.
        It is used to label the color bar.
    color_map: str or None
        Name of a matplotlib color map.
        If None 'jet' is used.
    """
    if isinstance(colors, (np.integer, int)):
        if colors < 0:
            self.color_set_index = -1
            self.color_index = 0
        else:
            self.color_set_index = 0
            self.color_index = colors
    else:
        if not isinstance(colors[0], (int, float,
                                      np.integer, np.floating)):
            # categorical data:
            self.extra_categories, extra = categorize(colors)
        else:
            # forget categories of a previous call:
            self.extra_categories = None
            extra = colors
        # convert to an array so that plain lists can be masked as well:
        self.extra_colors = np.asarray(extra)[self.valid_samples]
        self.extra_color_label = color_label
        self.color_set_index = -1
        self.color_index = 1
    self.color_map = plt.get_cmap(color_map if color_map else 'jet')

310 

311 

def show(self, ioff=True):
    """Show interactive scatter plots for exploration.

    Creates the figure, connects the key, resize and pick handlers,
    sets up histograms, scatter plots, optional waveform axes, the
    hidden magnified plot and the hidden help page, and finally calls
    `plt.show()`.

    Parameters
    ----------
    ioff: bool
        If True switch matplotlib's interactive mode off before
        plotting, otherwise switch it on.
    """
    toggle_interactive = plt.ioff if ioff else plt.ion
    toggle_interactive()
    for key, value in (('toolbar', 'None'),
                       ('keymap.quit', 'ctrl+w, alt+q, ctrl+q, q'),
                       ('font.size', 12)):
        plt.rcParams[key] = value
    self.fig = plt.figure(facecolor='white', figsize=(10, 8))
    window_title = self.title + ': ' + self.all_titles[self.show_mode]
    self.fig.canvas.manager.set_window_title(window_title)
    for event, handler in (('key_press_event', self._on_key),
                           ('resize_event', self._on_resize),
                           ('pick_event', self._on_pick)):
        self.fig.canvas.mpl_connect(event, handler)
    if self.color_map is None:
        self.color_map = plt.get_cmap('jet')
    self._set_color_column()
    self._init_hist_plots()
    self._init_scatter_plots()
    self.wave_ax = []
    if self.wave_data is not None and len(self.wave_data) > 0:
        n_axes = len(self.wave_has_xticks)
        shared_x = None  # first axes of the current group sharing an x-axis
        label_index = 0  # next entry of wave_xlabels to be used
        for k, has_xticks in enumerate(self.wave_has_xticks):
            ax = self.fig.add_subplot(1, n_axes, 1 + k, sharex=shared_x)
            self.wave_ax.append(ax)
            if not has_xticks:
                if shared_x is None:
                    shared_x = ax
                continue
            # the axes closing a group gets the x-label; missing
            # labels are filled up with empty strings:
            if label_index >= len(self.wave_xlabels):
                self.wave_xlabels.append('')
            ax.set_xlabel(self.wave_xlabels[label_index])
            label_index += 1
            shared_x = None
        for ax, ylabel in zip(self.wave_ax, self.wave_ylabels):
            ax.set_ylabel(ylabel)
        has_title_text = not isinstance(self.wave_title, bool)
        if has_title_text and self.wave_title:
            self.wave_ax[0].set_title(self.wave_title)
        self.fix_waveform_plot(self.wave_ax, self.mark_data)
    self._plot_magnified_scatter()
    self._plot_help()
    plt.show()

358 

359 

def _pca_header(self, data, labels):
    """Set up header for the table of principal components.

    Every entry of `self.pca_tables` is replaced by a new, empty
    TableData with a 'PC' column, a 'variance' column and one column
    per feature. Units given in a label after '[' or '/' are stripped.

    Parameters
    ----------
    data: ndarray of float
        The data (samples x features) without invalid (infinite or
        NaN) numbers. Not used for building the header.
    labels: list of str
        Labels of the features.
    """
    # iterate over the labels only: pairing them with the rows of
    # `data` would drop labels if there are fewer samples than features.
    lbs = []
    for l in labels:
        if '[' in l:
            lbs.append(l.split('[')[0].strip())
        elif '/' in l:
            lbs.append(l.split('/')[0].strip())
        else:
            lbs.append(l)
    header = TableData(header=lbs)
    header.set_formats('%.3f')
    header.insert(0, ['PC'] + ['-']*header.nsecs, '', '%d')
    header.insert(1, 'variance', '%', '%.3f')
    for k in range(len(self.pca_tables)):
        self.pca_tables[k] = TableData(header)

385 

386 

def compute_pca(self, scale=False, write=False):
    """Compute PCA based on the data.

    The projected data, their labels and the number of components to
    be shown initially are stored in slot 1 (`scale=False`) or slot 2
    (`scale=True`) of `self.all_data`, `self.all_labels` and
    `self.all_maxcols`. The component weights are stored in the
    corresponding table of `self.pca_tables`.

    Parameters
    ----------
    scale: boolean
        If True standardize data before computing PCA, i.e. remove mean
        of each variable and divide by its standard deviation.
    write: boolean
        If True write PCA components to standard out.
    """
    # pca:
    pca = decomposition.PCA()
    if scale:
        scaler = preprocessing.StandardScaler()
        pca_input = scaler.fit_transform(self.raw_data)
        pca_label = 'sPC'
    else:
        pca_input = self.raw_data
        pca_label = 'PC'
    pca.fit(pca_input)
    # flip components such that their largest weight is positive:
    for k in range(len(pca.components_)):
        if np.abs(np.min(pca.components_[k])) > np.max(pca.components_[k]):
            pca.components_[k] *= -1.0
    # project the same data the PCA was fitted on (i.e. the
    # standardized data if scale is True, not the raw data):
    pca_data = pca.transform(pca_input)
    # small variances get one more digit:
    pca_labels = [f'{pca_label}{k+1} ' +
                  (f'({100*v:.1f}%)' if v > 0.01 else f'({100*v:.2f}%)')
                  for k, v in enumerate(pca.explained_variance_ratio_)]
    # initially show only components explaining at least 1% of the
    # variance, but at least two and no more than six:
    if np.min(pca.explained_variance_ratio_) >= 0.01:
        pca_maxcols = pca_data.shape[1]
    else:
        pca_maxcols = int(np.argmax(pca.explained_variance_ratio_ < 0.01))
    if pca_maxcols < 2:
        pca_maxcols = 2
    if pca_maxcols > 6:
        pca_maxcols = 6
    # table with PCA feature weights:
    pca_table = self.pca_tables[1] if scale else self.pca_tables[0]
    pca_table.clear_data()
    pca_table.set_section(pca_label, 0, pca_table.nsecs)
    for k, comp in enumerate(pca.components_):
        pca_table.append_data(k+1, 0)
        pca_table.append_data(100.0*pca.explained_variance_ratio_[k])
        pca_table.append_data(comp)
    if write:
        pca_table.write(table_format='out', unit_style='none')
    # submit data:
    slot = 2 if scale else 1
    self.all_data[slot] = pca_data
    self.all_labels[slot] = pca_labels
    self.all_maxcols[slot] = pca_maxcols

441 

442 

def save_pca(self, file_name, scale, **kwargs):
    """Write PCA data to file.

    '-pcacor' (scaled) or '-pcacov' (unscaled) is appended to
    `file_name`. Without a 'table_format' key word argument the file
    gets the extension '.dat' and all other key word arguments are
    ignored. The table is always written with `unit_style='none'`.

    Parameters
    ----------
    file_name: str
        Name of output file.
    scale: boolean
        If True write PCA components of standardized PCA.
    kwargs: dict
        Additional parameter for TableData.write()
    """
    suffix, table_index = ('-pcacor', 1) if scale else ('-pcacov', 0)
    target = file_name + suffix
    table = self.pca_tables[table_index]
    # the unit style is fixed, a user supplied one is discarded:
    kwargs.pop('unit_style', None)
    if 'table_format' not in kwargs:
        table.write(target + '.dat', unit_style='none')
        return
    table.write(target, unit_style='none', **kwargs)

468 

469 

def _set_color_column(self):
    """Initialize variables used for colorization of scatter points.

    Selects the values and the label used for coloring according to
    `self.color_set_index` and `self.color_index`, queries
    `fix_scatter_plot()` for the range and ticks of the color bar,
    and maps the values to colors in `self.data_colors`.
    """
    if self.color_set_index == -1:
        if self.color_index == 0:
            # color by sample index:
            self.color_values = np.arange(self.data.shape[0], dtype=float)
            self.color_label = 'sample'
        elif self.color_index == 1:
            # color by the extra column set via set_colors():
            self.color_values = self.extra_colors
            self.color_label = self.extra_color_label
    else:
        self.color_values = self.all_data[self.color_set_index][:,self.color_index]
        self.color_label = self.all_labels[self.color_set_index][self.color_index]
    self.color_vmin, self.color_vmax, self.color_ticks = \
        self.fix_scatter_plot(self.cbax, self.color_values,
                              self.color_label, 'c')
    if self.color_ticks is None:
        # categorical data get one tick per category:
        if self.color_set_index == 0 and \
           self.categories[self.color_index] is not None:
            self.color_ticks = np.arange(len(self.categories[self.color_index]))
        elif self.color_set_index == -1 and \
             self.color_index == 1 and \
             self.extra_categories is not None:
            self.color_ticks = np.arange(len(self.extra_categories))
    span = self.color_vmax - self.color_vmin
    if span == 0:
        # constant color values: avoid division by zero and map all
        # values to the lower end of the color map:
        normalized = np.zeros(np.shape(self.color_values))
    else:
        normalized = (self.color_values - self.color_vmin)/span
    self.data_colors = self.color_map(normalized)

494 

495 

def _add_backdrop(self, ax):
    """Add a light-gray rectangle behind the given axes.

    The rectangle covers the tight bounding box of the axes enlarged
    by `self.mborder` on each side and is remembered in
    `self.magnified_backdrop`. Nothing happens if no bounding box
    is available.

    Parameters
    ----------
    ax: matplotlib axes
        Axes that gets the backdrop.
    """
    renderer = self.fig.canvas.get_renderer()
    bbox = ax.get_tightbbox(renderer)
    if bbox is None:
        return
    margin = self.mborder
    lower_left = (bbox.x0 - margin, bbox.y0 - margin)
    backdrop = patches.Rectangle(lower_left,
                                 bbox.width + 2*margin,
                                 bbox.height + 2*margin,
                                 transform=None, clip_on=False,
                                 facecolor='#f7f7f7', edgecolor='none',
                                 zorder=-5)
    self.magnified_backdrop = backdrop
    ax.add_patch(backdrop)

508 

509 

def _create_selector(self, ax):
    """Create a rectangle selector for an axes.

    The selector reacts to the left mouse button and reports to
    `self._on_select`. If matplotlib does not accept the full set of
    arguments (TypeError), a selector with basic arguments only is
    created instead.

    Parameters
    ----------
    ax: matplotlib axes
        Axes the selector is attached to.

    Returns
    -------
    selector: matplotlib.widgets.RectangleSelector
        The new selector.
    """
    rectangle_props = dict(facecolor='gray', edgecolor='gray',
                           alpha=0.2, fill=True)
    modifier_keys = dict(move='', clear='', square='', center='ctrl')
    try:
        return widgets.RectangleSelector(ax, self._on_select,
                                         useblit=True, button=1,
                                         minspanx=0, minspany=0,
                                         spancoords='pixels',
                                         props=rectangle_props,
                                         state_modifier_keys=modifier_keys)
    except TypeError:
        # old matplotlib:
        return widgets.RectangleSelector(ax, self._on_select,
                                         useblit=True, button=1)

530 

531 

def _plot_hist(self, ax, magnifiedax):
    """Plot and label a histogram.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into. Either one of `self.hist_ax` or, for the
        magnified plot, one of `self.scatter_ax`.
    magnifiedax: bool
        True if `ax` is the magnified plot.
    """
    # find the feature column shown in this axes:
    if ax in self.hist_ax:
        in_hist = True
        idx = self.hist_ax.index(ax)
        c = self.hist_indices[idx]
    else:
        in_hist = False
        idx = self.scatter_ax.index(ax)
        c = self.scatter_indices[idx][0]
    ax.clear()
    ax.hist(self.data[:,c], self.hist_nbins)
    ax.set_xlabel(self.labels[c])
    ax.xaxis.set_major_locator(plt.AutoLocator())
    ax.xaxis.set_major_formatter(plt.ScalarFormatter())
    if self.show_mode == 0:
        cats = self.categories[c]
        if cats is not None:
            ax.set_xticks(np.arange(len(cats)))
            ax.set_xticklabels(cats)
        self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
    # the magnified plot and the first column show the sample count:
    show_count = magnifiedax or c == 0
    if show_count:
        ax.text(0.05, 0.9, f'n={len(self.data)}',
                transform=ax.transAxes, zorder=100)
        ax.set_ylabel('count')
    else:
        ax.yaxis.set_major_formatter(plt.NullFormatter())
    if magnifiedax:
        # use the same x-range as the corresponding small histogram:
        cax = self.hist_ax[self.scatter_indices[-1][0]]
        ax.set_xlim(cax.get_xlim())
    selector = self._create_selector(ax)
    if in_hist:
        self.hist_selector[idx] = selector
    else:
        self.scatter_selector[idx] = selector
        self.scatter_artists[idx] = None
    ax.relim(True)
    if magnifiedax:
        self._add_backdrop(ax)

578 

579 

def _set_hist_ylim(self):
    """Set a common y-range from zero to the largest upper limit of the
    first `self.maxcols` histograms on all histogram axes."""
    shown = self.hist_ax[:self.maxcols]
    upper_limits = [ax.get_ylim()[1] for ax in shown]
    top = np.max(upper_limits)
    for ax in self.hist_ax:
        ax.set_ylim(0, top)

584 

585 

def _init_hist_plots(self):
    """Initial plots of the histograms.

    Creates one histogram axes per data column in the bottom row of
    an n x n grid, plots the histograms and gives them a common
    y-range.
    """
    n = self.data.shape[1]
    # reset all three parallel lists, not only the axes, so that they
    # stay aligned when this function is called again:
    self.hist_ax = []
    self.hist_indices = []
    self.hist_selector = []
    for r in range(n):
        ax = self.fig.add_subplot(n, n, (n-1)*n+r+1)
        # register the axes first, _plot_hist() looks it up:
        self.hist_ax.append(ax)
        self.hist_indices.append(r)
        self.hist_selector.append(None)
        self._plot_hist(ax, False)
    self._set_hist_ylim()

597 

598 

def _plot_scatter(self, ax, magnifiedax, cax=None):
    """Plot a scatter plot.

    Depending on `self.scatter` either all data points are plotted
    and annotated with their correlation coefficient, or a 2D
    histogram is plotted. The selected data points are always
    plotted on top.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into. Must be contained in `self.scatter_ax`.
    magnifiedax: bool
        True if `ax` is the magnified plot.
    cax: matplotlib axes or None
        If not None, axes into which a color bar is drawn
        (scatter mode only).
    """
    idx = self.scatter_ax.index(ax)
    c, r = self.scatter_indices[idx]
    raw_mode = self.show_mode == 0
    if self.scatter:
        # scatter plot of all data points:
        ax.clear()
        points = ax.scatter(self.data[:,c], self.data[:,r], s=50,
                            edgecolors='white', linewidths=0.5,
                            picker=self.pick_radius, zorder=10)
        points.set_facecolor(self.data_colors)
        pr, pp = pearsonr(self.data[:,c], self.data[:,r])
        # bold font for correlations that are significant after
        # dividing the significance level by the number of columns:
        threshold = self.significance_level/self.data.shape[1]
        fw = 'bold' if pp < threshold else 'normal'
        info = f'r={pr:.2f}, p={pp:.3f}'
        text_args = dict(fontweight=fw, transform=ax.transAxes, zorder=100)
        if pr < 0:
            ax.text(0.95, 0.9, info, ha='right', **text_args)
        else:
            ax.text(0.05, 0.9, info, **text_args)
        # color bar:
        if cax is not None:
            # temporary mappable, only needed for creating the color bar:
            mappable = ax.scatter(self.data[:, c], self.data[:, r],
                                  c=self.color_values, cmap=self.color_map)
            self.fig.colorbar(mappable, cax=cax, ticks=self.color_ticks)
            mappable.remove()
            cax.set_ylabel(self.color_label)
            self.color_vmin, self.color_vmax, self.color_ticks = \
                self.fix_scatter_plot(self.cbax, self.color_values,
                                      self.color_label, 'c')
            if self.color_ticks is None:
                if self.color_set_index == 0 and \
                   self.categories[self.color_index] is not None:
                    cax.set_yticklabels(self.categories[self.color_index])
                elif self.color_set_index == -1 and \
                     self.color_index == 1 and \
                     self.extra_categories is not None:
                    cax.set_yticklabels(self.extra_categories)
    else:
        # 2D histogram over the current axis ranges:
        if raw_mode:
            self.fix_scatter_plot(ax, self.data[:,c], self.labels[c], 'x')
            self.fix_scatter_plot(ax, self.data[:,r], self.labels[r], 'y')
        axrange = [ax.get_xlim(), ax.get_ylim()]
        ax.clear()
        ax.hist2d(self.data[:,c], self.data[:,r], self.hist_nbins,
                  range=axrange, cmap=plt.get_cmap('Greys'))
    # selected data:
    marked = ax.scatter(self.data[self.mark_data, c],
                        self.data[self.mark_data, r], s=100,
                        edgecolors='black', linewidths=0.5,
                        picker=self.pick_radius, zorder=11)
    marked.set_facecolor(self.data_colors[self.mark_data])
    self.scatter_artists[idx] = marked
    # ticks:
    ax.xaxis.set_major_locator(plt.AutoLocator())
    ax.yaxis.set_major_locator(plt.AutoLocator())
    ax.xaxis.set_major_formatter(plt.ScalarFormatter())
    ax.yaxis.set_major_formatter(plt.ScalarFormatter())
    if raw_mode:
        xcats = self.categories[c]
        if xcats is not None:
            ax.set_xticks(np.arange(len(xcats)))
            ax.set_xticklabels(xcats)
        ycats = self.categories[r]
        if ycats is not None:
            ax.set_yticks(np.arange(len(ycats)))
            ax.set_yticklabels(ycats)
    # labels:
    if magnifiedax:
        ax.set_xlabel(self.labels[c])
        ax.set_ylabel(self.labels[r])
        # take over the ranges of the small plot showing the same pair:
        small = self.scatter_indices[:-1].index(self.scatter_indices[-1])
        cax = self.scatter_ax[small]
        ax.set_xlim(cax.get_xlim())
        ax.set_ylim(cax.get_ylim())
    elif c == 0:
        ax.set_ylabel(self.labels[r])
    if raw_mode:
        self.fix_scatter_plot(ax, self.data[:, c], self.labels[c], 'x')
        self.fix_scatter_plot(ax, self.data[:, r], self.labels[r], 'y')
    if magnifiedax:
        self._add_backdrop(ax)
    else:
        # small plots have no x-tick labels, y-tick labels only in
        # the first column, and share the ranges of the histograms:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        if c > 0:
            ax.yaxis.set_major_formatter(plt.NullFormatter())
        ax.set_xlim(*self.hist_ax[c].get_xlim())
        ax.set_ylim(*self.hist_ax[r].get_xlim())
    self.scatter_selector[idx] = self._create_selector(ax)
    ax.relim(True)

684 

685 

def _init_scatter_plots(self):
    """Initial plots of scatter plots.

    Creates the color bar axes and one scatter plot for each pair of
    data columns in the lower triangle of an n x n grid. The color
    bar is drawn together with the first scatter plot only.
    """
    self.cbax = self.fig.add_axes([0.5, 0.5, 0.1, 0.5])
    n = self.data.shape[1]
    pairs = [(col, row) for row in range(1, n) for col in range(row)]
    for k, (col, row) in enumerate(pairs):
        ax = self.fig.add_subplot(n, n, (row - 1)*n + col + 1)
        # register the axes first, _plot_scatter() looks it up:
        self.scatter_ax.append(ax)
        self.scatter_indices.append([col, row])
        self.scatter_artists.append(None)
        self.scatter_selector.append(None)
        colorbar_ax = self.cbax if k == 0 else None
        self._plot_scatter(ax, False, colorbar_ax)

700 

701 

def _plot_magnified_scatter(self):
    """Initial plot of the magnified scatter plot.

    Creates an invisible axes showing the first two data columns and
    registers it as the last entry of the scatter plot lists.
    """
    col, row = 0, 1
    ax = self.fig.add_axes([0.5, 0.9, 0.05, 0.05])
    ax.set_visible(False)
    self.magnified_on = False
    # all data points:
    points = ax.scatter(self.data[:, col], self.data[:, row],
                        s=50, edgecolors='none')
    points.set_facecolor(self.data_colors)
    # selected data points:
    marked = ax.scatter(self.data[self.mark_data, col],
                        self.data[self.mark_data, row], s=80)
    marked.set_facecolor(self.data_colors[self.mark_data])
    ax.set_xlabel(self.labels[col])
    ax.set_ylabel(self.labels[row])
    self.fix_scatter_plot(ax, self.data[:, col], self.labels[col], 'x')
    self.fix_scatter_plot(ax, self.data[:, row], self.labels[row], 'y')
    # the magnified plot is always the last one in the lists:
    self.scatter_ax.append(ax)
    self.scatter_indices.append([col, row])
    self.scatter_artists.append(marked)
    self.scatter_selector.append(None)

723 

724 

def _plot_help(self):
    """Create the hidden axes listing all key shortcuts and mouse actions.

    The axes is stored in `self.help_ax`.
    """
    ax = self.fig.add_subplot(1, 1, 1)
    ax.set_position([0.02, 0.02, 0.96, 0.96])
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    # one line per action plus headings and spacing:
    nlines = len(self.mouse_actions) + len(self.key_actions) + 4
    dy = 1/nlines
    sections = (('Key shortcuts', self.key_actions),
                ('Mouse actions', self.mouse_actions))
    y = 1 - 1.5*dy
    for k, (heading, actions) in enumerate(sections):
        if k > 0:
            # empty line between the sections:
            y -= dy
        ax.text(0.05, y, heading, transform=ax.transAxes,
                fontweight='bold')
        y -= dy
        for action, description in actions:
            ax.text(0.05, y, action, transform=ax.transAxes)
            ax.text(0.3, y, description, transform=ax.transAxes)
            y -= dy
    ax.set_visible(False)
    self.help_ax = ax

750 

751 

def fix_scatter_plot(self, ax, data, label, axis):
    """Customize an axes of a scatter plot.

    This function is called after a scatter plot has been plotted.
    Once for the x axis, once for the y axis and once for the color bar.
    Reimplement this function to set appropriate limits and ticks.

    Return values are only used for the color bar (`axis='c'`).
    Otherwise they are ignored.

    This default implementation does not touch `ax`. It returns the
    minimum and maximum of `data`, ignoring NaNs, and no ticks.

    For example, ticks for phase variables can be nicely labeled
    using the unicode character for pi:
    ```
    if 'phase' in label:
        if axis == 'y':
            ax.set_ylim(0.0, 2.0*np.pi)
            ax.set_yticks(np.arange(0.0, 2.5*np.pi, 0.5*np.pi))
            ax.set_yticklabels(['0', u'\u03c0/2', u'\u03c0', u'3\u03c0/2', u'2\u03c0'])
    ```

    Parameters
    ----------
    ax: matplotlib axes
        Axes of the scatter plot or color bar to be worked on.
    data: 1D array
        Data array of the axes.
    label: str
        Label corresponding to the data array.
    axis: str
        'x', 'y': set properties of x or y axes of ax.
        'c': set properties of color bar axes (note that ax can be None!)
             and return vmin, vmax, and ticks.

    Returns
    -------
    min: float
        minimum value of color bar axis
    max: float
        maximum value of color bar axis
    ticks: list of float
        position of ticks for color bar axis
    """
    lowest = np.nanmin(data)
    highest = np.nanmax(data)
    ticks = None
    return lowest, highest, ticks

795 

796 

def fix_waveform_plot(self, axs, indices):
    """Customize waveform plots.

    This function is called once after new data have been plotted
    into the waveform plots. Reimplement this function to customize
    these plots. In particular to set axis limits and labels, plot
    title, etc.
    You may even open a new figure (with non-blocking `show()`).

    This default implementation does nothing.

    The following member variables might be useful:
    - `self.wave_data`: the full list of waveform data.
    - `self.wave_nested`: True if the elements of `self.wave_data` are lists of 2D arrays. Otherwise the elements are 2D arrays. The first column of a 2D array contains the x-values, further columns y-values.
    - `self.wave_has_xticks`: List of booleans for each axis. True if the axis has its own xticks.
    - `self.wave_xlabels`: List of xlabels (only for the axis where the corresponding entry in `self.wave_has_xticks` is True).
    - `self.wave_ylabels`: for each axis its ylabel

    For example, you can set the linewidth of all plotted waveforms via:
    ```
    for ax in axs:
        for l in ax.lines:
            l.set_linewidth(3.0)
    ```
    or enable markers to be plotted:
    ```
    for ax, yl in zip(axs, self.wave_ylabels):
        if 'Power' in yl:
            for l in ax.lines:
                l.set_marker('.')
                l.set_markersize(15.0)
                l.set_markeredgewidth(0.5)
                l.set_markeredgecolor('k')
                l.set_markerfacecolor(l.get_color())
    ```
    Useful is to reduce the maximum number of y-ticks:
    ```
    axs[0].yaxis.get_major_locator().set_params(nbins=7)
    ```
    or
    ```
    import matplotlib.ticker as ticker
    axs[0].yaxis.set_major_locator(ticker.MaxNLocator(nbins=4))
    ```

    Parameters
    ----------
    axs: list of matplotlib axes
        Axis of the waveform plots to be worked on.
    indices: list of int
        Indices of the waveforms that have been selected and plotted.
    """
    return None

848 

849 

def list_selection(self, indices):
    """Report the currently selected data points.

    Called when 'l' is pressed. The default implementation prints,
    for each selected data point, the corresponding entry of
    `self.valid_rows` on the console. Reimplement this function to
    print more meaningful information about the selection, or to do
    anything else you like with it.

    Parameters
    ----------
    indices: list of int
        Indices of the data points that have been selected.
    """
    header = ('', 'selected rows in data table:')
    for line in header:
        print(line)
    # look up lazily, so rows are printed one by one as they are resolved:
    for row in (self.valid_rows[k] for k in indices):
        print(row)

867 

868 

def analyze_selection(self, index):
    """Hook for showing details about a single selected data point.

    Called when a single data item was double clicked. The default
    implementation does nothing. Reimplement this function to
    provide further details on this data point, for example in an
    additional figure window. Show such a window non-blocking:
    `plt.show(block=False)`

    Parameters
    ----------
    index: int
        The index of the selected data point.
    """
    return None

884 

885 

def _set_magnified_pos(self, width, height):
    """Place the magnified plot over the plot it magnifies.

    The magnified plot is the last entry of `self.scatter_ax`. If
    magnification is off, it is parked at a small dummy position and
    hidden. Otherwise it is centered on the scatter plot (or
    histogram) with the same column/row indices, and then shifted
    such that it stays within the left/bottom borders and the
    right/top spacing of the figure.

    Parameters
    ----------
    width: float
        Width of the figure (same units as `self.xborder` and `self.spacing`).
    height: float
        Height of the figure (same units as `self.yborder` and `self.spacing`).
    """
    mag_ax = self.scatter_ax[-1]
    if not self.magnified_on:
        mag_ax.set_position([0.5, 0.9, 0.05, 0.05])
        mag_ax.set_visible(False)
        return
    xoffs = self.xborder/width
    yoffs = self.yborder/height
    col, row = self.scatter_indices[-1]
    if row < self.data.shape[1]:
        # a scatter plot is magnified: find the small plot with the same indices
        source_ax = self.scatter_ax[self.scatter_indices[:-1].index(self.scatter_indices[-1])]
    else:
        # a row index beyond the data columns marks a magnified histogram
        source_ax = self.hist_ax[col]
    corners = source_ax.get_position().get_points()
    size = self.magnified_size
    # center on the source plot:
    lower = np.mean(corners, 0) - 0.5*size
    # keep clear of the left and bottom border:
    if lower[0] < xoffs:
        lower[0] = xoffs
    if lower[1] < yoffs:
        lower[1] = yoffs
    upper = lower + size
    # keep clear of the right and top edge:
    xmax = 1.0 - self.spacing/width
    ymax = 1.0 - self.spacing/height
    if upper[0] > xmax:
        upper[0] = xmax
    if upper[1] > ymax:
        upper[1] = ymax
    lower = upper - size
    mag_ax.set_position([lower[0], lower[1], size[0], size[1]])
    mag_ax.set_visible(True)

909 

910 

def _make_selection(self, ax, key, x0, x1, y0, y1):
    """Select points from a scatter or histogram plot.

    Updates `self.mark_data`, the list of indices of selected data
    points, from a rectangle in data coordinates:

    - no modifier key: the selection is replaced by the points
      inside the rectangle,
    - 'shift': points inside the rectangle are added to the selection,
    - 'control': points inside the rectangle are removed from the selection.

    For histograms only the x-range of the rectangle is considered.

    Parameters
    ----------
    ax: matplotlib axes
        The axes in which the rectangle was drawn.
    key: str or None
        Modifier key pressed while selecting.
    x0, x1: float
        Lower and upper x-limit of the rectangle.
    y0, y1: float
        Lower and upper y-limit of the rectangle.
    """
    if key not in ('shift', 'control'):
        self.mark_data = []
    elif not isinstance(self.mark_data, list):
        # a pick event may have stored an index array;
        # the updates below need list semantics
        self.mark_data = [int(i) for i in self.mark_data]
    # mask of all data points inside the selected range:
    if ax in self.scatter_ax:
        c, r = self.scatter_indices[self.scatter_ax.index(ax)]
        inside = (self.data[:, c] >= x0) & (self.data[:, c] <= x1)
        if r < self.data.shape[1]:
            # a real scatter plot; otherwise this is the magnified
            # histogram, for which only the x-range matters
            inside &= (self.data[:, r] >= y0) & (self.data[:, r] <= y1)
    elif ax in self.hist_ax:
        c = self.hist_indices[self.hist_ax.index(ax)]
        inside = (self.data[:, c] >= x0) & (self.data[:, c] <= x1)
    else:
        return
    hits = [int(i) for i in np.flatnonzero(inside)]
    if key == 'control':
        dropped = set(hits)
        self.mark_data[:] = [i for i in self.mark_data if i not in dropped]
    else:
        marked = set(self.mark_data)
        self.mark_data.extend(i for i in hits if i not in marked)

947 

948 

def _update_selection(self):
    """Show the current selection in the scatter and waveform plots.

    The marker artist of each scatter plot gets the positions and
    colors of the data points listed in `self.mark_data`. If there
    are waveform axes, they are cleared, the waveforms of the
    selected data points are plotted, labels and title are restored,
    and `fix_waveform_plot()` is called. Finally the canvas is redrawn.
    """
    # highlight selected points in each scatter plot that has a marker artist:
    for marker, (col, row) in zip(self.scatter_artists, self.scatter_indices):
        if marker is None:
            continue
        if len(self.mark_data) == 0:
            offsets = np.zeros((0, 2))
        else:
            offsets = list(zip(self.data[self.mark_data, col],
                               self.data[self.mark_data, row]))
        marker.set_offsets(offsets)
        marker.set_facecolors(self.data_colors[self.mark_data])
    # waveform plots:
    if len(self.wave_ax) > 0:
        group = 0    # index of the 2D array in nested data and of the xlabel
        ycol = 1     # column with the y-values for the current axes
        for k, wax in enumerate(self.wave_ax):
            wax.clear()
            for idx in self.mark_data:
                if self.wave_nested:
                    trace = self.wave_data[idx][group]
                else:
                    trace = self.wave_data[idx]
                if trace is None:
                    continue
                wax.plot(trace[:, 0], trace[:, ycol],
                         c=self.data_colors[idx],
                         picker=self.pick_radius)
            ycol += 1
            if self.wave_has_xticks[k]:
                # axes with own xticks close a group of axes:
                # continue with the first y-column of the next group
                wax.set_xlabel(self.wave_xlabels[group])
                ycol = 1
                group += 1
        for wax, ylabel in zip(self.wave_ax, self.wave_ylabels):
            wax.set_ylabel(ylabel)
        # a title given as a bool has no text to show:
        show_title = not isinstance(self.wave_title, bool) and self.wave_title
        if show_title:
            self.wave_ax[0].set_title(self.wave_title)
        self.fix_waveform_plot(self.wave_ax, self.mark_data)
    self.fig.canvas.draw()

989 

990 

def _set_limits(self, ax, x0, x1, y0, y1):
    """Set axis limits of a plot and of all plots showing the same columns.

    If `ax` is a histogram, its x-range is set, the y-range is
    applied to all histograms, and the x-range is passed on to the
    matching axis of all scatter plots showing the same data column.
    If `ax` is a scatter plot, the x- and y-ranges are applied to the
    histograms of the two data columns and to the matching axes of
    all scatter plots showing one of these columns.

    Parameters
    ----------
    ax: matplotlib axes
        The axes whose limits are set.
    x0, x1: float
        New lower and upper limit of the x-axis of `ax`.
    y0, y1: float
        New lower and upper limit of the y-axis of `ax`.
    """
    if ax in self.hist_ax:
        ax.set_xlim(x0, x1)
        for hist in self.hist_ax:
            hist.set_ylim(y0, y1)
        column = self.hist_indices[self.hist_ax.index(ax)]
        for panel, (c, r) in zip(self.scatter_ax, self.scatter_indices):
            if c == column:
                panel.set_xlim(x0, x1)
            if r == column:
                panel.set_ylim(x0, x1)
    if ax in self.scatter_ax:
        xcol, ycol = self.scatter_indices[self.scatter_ax.index(ax)]
        # NOTE(review): for the magnified histogram ycol is beyond the data
        # columns - confirm that self.hist_ax[ycol] exists in that case.
        self.hist_ax[xcol].set_xlim(x0, x1)
        self.hist_ax[ycol].set_xlim(y0, y1)
        ranges = ((xcol, (x0, x1)), (ycol, (y0, y1)))
        for panel, (c, r) in zip(self.scatter_ax, self.scatter_indices):
            for column, limits in ranges:
                if c == column:
                    panel.set_xlim(*limits)
            for column, limits in ranges:
                if r == column:
                    panel.set_ylim(*limits)

1016 

1017 

def _on_key(self, event):
    """Handle key events.

    Key bindings handled here:

    - arrow keys: show the magnified plot and move it to the
      neighboring scatter plot or histogram,
    - 'escape': hide the magnified plot,
    - 'o', 'z': toggle zooming by rectangular selections,
    - 'backspace': restore the limits from the last entry of the zoom stack,
    - '+', '=', '-', '0': increase, decrease, or reset the pick radius,
    - 'pageup', '<', 'pagedown', '>': decrease or increase the
      number of displayed data columns,
    - 'w': toggle between waveform plots only and all plots,
    - 'ctrl+a': select all data points,
    - 'c', 'C': cycle backward or forward through the columns used for coloring,
    - 'n', 'N': decrease or increase the number of histogram bins,
    - 'H': toggle `self.scatter` and replot the scatter plots,
    - 'p', 'P': cycle backward or forward through data and
      principal components,
    - 'l': list the selected data points,
    - 'h': toggle visibility of the help window.

    Parameters
    ----------
    event: matplotlib key event
        The event; only its `key` attribute is used.
    """
    #print('pressed', event.key)
    if event.key is None:
        # matplotlib reports unknown keys as None - nothing to do
        return
    if event.key in ['left', 'right', 'up', 'down']:
        if self.magnified_on:
            # column and row of the currently magnified plot.
            # A row index of self.data.shape[1] marks a histogram.
            mc, mr = self.scatter_indices[-1]
            if event.key == 'left':
                if mc > 0:
                    self.scatter_indices[-1][0] -= 1
                elif mr > 1:
                    # wrap to the last plot of the previous row:
                    if mr >= self.data.shape[1]:
                        self.scatter_indices[-1][1] = self.maxcols - 1
                    else:
                        self.scatter_indices[-1][1] -= 1
                    self.scatter_indices[-1][0] = self.scatter_indices[-1][1] - 1
                else:
                    # wrap around to the histogram of the last column:
                    self.scatter_indices[-1][0] = self.data.shape[1] - 1
                    self.scatter_indices[-1][1] = self.data.shape[1]
            elif event.key == 'right':
                if mc < mr - 1 and mc < self.maxcols - 1:
                    self.scatter_indices[-1][0] += 1
                elif mr < self.maxcols:
                    self.scatter_indices[-1][0] = 0
                    self.scatter_indices[-1][1] += 1
                    # beyond the displayed rows come the histograms
                    # (test the updated row, the old one is always < maxcols here):
                    if self.scatter_indices[-1][1] >= self.maxcols:
                        self.scatter_indices[-1][1] = self.data.shape[1]
                else:
                    self.scatter_indices[-1][0] = 0
                    self.scatter_indices[-1][1] = 1
            elif event.key == 'up':
                if mr > mc + 1:
                    if mr >= self.data.shape[1]:
                        self.scatter_indices[-1][1] = self.maxcols - 1
                    else:
                        self.scatter_indices[-1][1] -= 1
                elif mc > 0:
                    self.scatter_indices[-1][0] -= 1
                    self.scatter_indices[-1][1] = self.data.shape[1]
                else:
                    self.scatter_indices[-1][0] = self.data.shape[1] - 1
                    self.scatter_indices[-1][1] = self.data.shape[1]
            elif event.key == 'down':
                if mr < self.maxcols:
                    self.scatter_indices[-1][1] += 1
                    # beyond the displayed rows come the histograms
                    # (test the updated row, the old one is always < maxcols here):
                    if self.scatter_indices[-1][1] >= self.maxcols:
                        self.scatter_indices[-1][1] = self.data.shape[1]
                elif mc < self.maxcols - 1:
                    self.scatter_indices[-1][0] += 1
                    self.scatter_indices[-1][1] = mc + 2
                    if self.scatter_indices[-1][1] >= self.maxcols:
                        self.scatter_indices[-1][1] = self.data.shape[1]
                else:
                    self.scatter_indices[-1][0] = 0
                    self.scatter_indices[-1][1] = 1
            # zoom levels of the previously magnified plot are not valid anymore:
            for k in reversed(range(len(self.zoom_stack))):
                if self.zoom_stack[k][0] == self.scatter_ax[-1]:
                    del self.zoom_stack[k]
            self.scatter_ax[-1].clear()
            self.scatter_ax[-1].set_visible(True)
        self.magnified_on = True
        self._set_magnified_pos(self.fig.get_window_extent().width,
                                self.fig.get_window_extent().height)
        if self.scatter_indices[-1][1] < self.data.shape[1]:
            self._plot_scatter(self.scatter_ax[-1], True)
        else:
            self._plot_hist(self.scatter_ax[-1], True)
        self.fig.canvas.draw()
    else:
        if event.key == 'escape':
            if len(self.scatter_ax) >= self.data.shape[1]:
                self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
                self.magnified_on = False
                self.scatter_ax[-1].set_visible(False)
                self.fig.canvas.draw()
        elif event.key in ('o', 'z'):
            self.select_zooms = not self.select_zooms
        elif event.key == 'backspace':
            if len(self.zoom_stack) > 0:
                self._set_limits(*self.zoom_stack.pop())
                self.fig.canvas.draw()
        elif event.key in ('+', '='):
            self.pick_radius *= 1.5
        elif event.key == '-':
            if self.pick_radius > 5.0:
                self.pick_radius /= 1.5
        elif event.key == '0':
            self.pick_radius = 4.0
        elif event.key in ['pageup', 'pagedown', '<', '>']:
            if event.key in ['pageup', '<'] and self.maxcols > 2:
                self.maxcols -= 1
            elif event.key in ['pagedown', '>'] and self.maxcols < self.raw_data.shape[1]:
                self.maxcols += 1
            for ax in self.hist_ax:
                self._plot_hist(ax, False)
            self._update_layout()
        elif event.key == 'w':
            if len(self.wave_data) > 0:
                if self.maxcols > 0:
                    # remember the number of columns and hide all but the waveforms:
                    self.all_maxcols[self.show_mode] = self.maxcols
                    self.maxcols = 0
                else:
                    self.maxcols = self.all_maxcols[self.show_mode]
                self._set_layout(self.fig.get_window_extent().width,
                                 self.fig.get_window_extent().height)
                self.fig.canvas.draw()
        elif event.key == 'ctrl+a':
            self.mark_data = list(range(len(self.data)))
            self._update_selection()
        elif event.key in ('c', 'C'):
            if event.key == 'c':
                # previous color column; a color set index of -1
                # addresses colors that are not part of the data sets:
                self.color_index -= 1
                if self.color_index < 0:
                    self.color_set_index -= 1
                    if self.color_set_index < -1:
                        self.color_set_index = len(self.all_data)-1
                    if self.color_set_index >= 0:
                        if self.all_data[self.color_set_index] is None:
                            self.compute_pca(self.color_set_index>1, True)
                        self.color_index = self.all_data[self.color_set_index].shape[1]-1
                    else:
                        self.color_index = 0 if self.extra_colors is None else 1
                self._set_color_column()
            else:
                # next color column:
                self.color_index += 1
                if (self.color_set_index >= 0 and \
                    self.color_index >= self.all_data[self.color_set_index].shape[1]) or \
                    (self.color_set_index < 0 and \
                     self.color_index >= (1 if self.extra_colors is None else 2)):
                    self.color_index = 0
                    self.color_set_index += 1
                    if self.color_set_index >= len(self.all_data):
                        self.color_set_index = -1
                    elif self.all_data[self.color_set_index] is None:
                        self.compute_pca(self.color_set_index>1, True)
                self._set_color_column()
            # apply the new colors to all plotted data:
            for ax in self.scatter_ax:
                ax.collections[0].set_facecolors(self.data_colors)
            for a in self.scatter_artists:
                if a is not None:
                    a.set_facecolors(self.data_colors[self.mark_data])
            for ax in self.wave_ax:
                for l, c in zip(ax.lines, self.data_colors[self.mark_data]):
                    l.set_color(c)
                    l.set_markerfacecolor(c)
            self._plot_scatter(self.scatter_ax[0], False, self.cbax)
            self.fix_scatter_plot(self.cbax, self.color_values,
                                  self.color_label, 'c')
            self.fig.canvas.draw()
        elif event.key in ('n', 'N'):
            if event.key == 'N':
                self.hist_nbins = (self.hist_nbins*3)//2
            elif self.hist_nbins >= 15:
                self.hist_nbins = (self.hist_nbins*2)//3
            else:
                self.hist_nbins = 10
            for ax in self.hist_ax:
                self._plot_hist(ax, False)
            self._set_hist_ylim()
            if self.scatter_indices[-1][1] >= self.data.shape[1]:
                self._plot_hist(self.scatter_ax[-1], True, True)
            elif not self.scatter:
                self._plot_scatter(self.scatter_ax[-1], True)
            if not self.scatter:
                for ax in self.scatter_ax[:-1]:
                    self._plot_scatter(ax, False)
            self.fig.canvas.draw()
        elif event.key == 'H':
            self.scatter = not self.scatter
            for ax in self.scatter_ax[:-1]:
                self._plot_scatter(ax, False)
            if self.scatter_indices[-1][1] < self.data.shape[1]:
                self._plot_scatter(self.scatter_ax[-1], True)
            self.fig.canvas.draw()
        elif event.key in ('p', 'P'):
            # hide the magnified plot and reset it to the first scatter plot:
            if len(self.scatter_ax) >= self.data.shape[1]:
                self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
            self.scatter_indices[-1] = [0, 1]
            self.magnified_on = False
            self.scatter_ax[-1].set_visible(False)
            self.all_maxcols[self.show_mode] = self.maxcols
            if event.key == 'P':
                self.show_mode += 1
                if self.show_mode >= len(self.all_data):
                    self.show_mode = 0
            else:
                self.show_mode -= 1
                if self.show_mode < 0:
                    self.show_mode = len(self.all_data)-1
            if self.show_mode == 1:
                print('principal components')
            elif self.show_mode == 2:
                print('scaled principal components')
            else:
                print('data')
            # principal components are computed on first use:
            if self.all_data[self.show_mode] is None:
                self.compute_pca(self.show_mode>1, True)
            self.data = self.all_data[self.show_mode]
            self.labels = self.all_labels[self.show_mode]
            self.maxcols = self.all_maxcols[self.show_mode]
            self.zoom_stack = []
            self.fig.canvas.manager.set_window_title(self.title + ': ' + self.all_titles[self.show_mode])
            for ax in self.hist_ax:
                self._plot_hist(ax, False)
            self._set_hist_ylim()
            for ax in self.scatter_ax:
                self._plot_scatter(ax, False)
            self._update_layout()
        elif event.key == 'l':
            if len(self.mark_data) > 0:
                self.list_selection(self.mark_data)
        elif event.key == 'h':
            self.help_ax.set_visible(not self.help_ax.get_visible())
            self.fig.canvas.draw()

1231 

1232 

def _on_select(self, eclick, erelease):
    """Handle selection events.

    A double click calls `analyze_selection()` on the last selected
    data point. Otherwise the rectangle spanned by the press and
    release positions selects data points. A rectangle smaller than
    2% of the axes ranges is treated as a single click and replaced
    by a rectangle of `self.pick_radius` pixels around the release
    position. Larger rectangles additionally zoom into the plot
    if `self.select_zooms` is set.

    Parameters
    ----------
    eclick: matplotlib mouse event
        The button press event.
    erelease: matplotlib mouse event
        The button release event.
    """
    if eclick.dblclick:
        if len(self.mark_data) > 0:
            self.analyze_selection(self.mark_data[-1])
        return
    left = min(eclick.xdata, erelease.xdata)
    right = max(eclick.xdata, erelease.xdata)
    bottom = min(eclick.ydata, erelease.ydata)
    top = max(eclick.ydata, erelease.ydata)
    ax = erelease.inaxes
    if ax is None:
        ax = eclick.inaxes
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    is_click = (right - left < 0.02*(xmax-xmin) and
                top - bottom < 0.02*(ymax-ymin))
    if is_click:
        # size of the axes in pixels:
        extent = ax.get_window_extent().transformed(self.fig.dpi_scale_trans.inverted())
        width = extent.width*self.fig.dpi
        height = extent.height*self.fig.dpi
        # pick radius in data coordinates:
        dx = self.pick_radius*(xmax-xmin)/width
        dy = self.pick_radius*(ymax-ymin)/height
        left, right = erelease.xdata - dx, erelease.xdata + dx
        bottom, top = erelease.ydata - dy, erelease.ydata + dy
    elif self.select_zooms:
        # remember the current limits for undoing the zoom:
        self.zoom_stack.append((ax, xmin, xmax, ymin, ymax))
        self._set_limits(ax, left, right, bottom, top)
    self._make_selection(ax, erelease.key, left, right, bottom, top)
    self._update_selection()

1266 

1267 

def _on_pick(self, event):
    """Handle pick events.

    Picking a waveform reduces the selection to the data point this
    waveform belongs to. Picking points in a scatter plot selects
    these points. A double click in addition calls
    `analyze_selection()` on the last selected data point.

    Parameters
    ----------
    event: matplotlib pick event
        The event; its `artist`, `ind`, and `mouseevent` attributes are used.
    """
    for ax in self.wave_ax:
        for k, l in enumerate(ax.lines):
            if l is event.artist:
                # NOTE(review): assumes the k-th line belongs to the k-th
                # selected data point, i.e. no waveform was skipped as None.
                self.mark_data = [self.mark_data[k]]
    for ax in self.scatter_ax:
        if ax.collections[0] is event.artist:
            # store a plain list of ints: the selection is modified
            # with list methods elsewhere, which an index array lacks
            self.mark_data = [int(i) for i in event.ind]
    self._update_selection()
    if event.mouseevent.dblclick:
        if len(self.mark_data) > 0:
            self.analyze_selection(self.mark_data[-1])

1281 

1282 

def _set_layout(self, width, height):
    """Update positions and visibility of all plots.

    Histograms are placed in the bottom row, scatter plots in the
    rows above, the color bar next to the first column of the top
    row, and the waveform plots fill the upper right of the figure.
    Plots beyond `self.maxcols` columns are hidden and parked at a
    tiny dummy position. With `self.maxcols == 0` only the waveform
    plots are shown.

    Parameters
    ----------
    width: float
        Width of the figure (presumably in pixels, same units as
        `self.xborder` and `self.spacing` - TODO confirm).
    height: float
        Height of the figure (same units as `self.yborder` and `self.spacing`).
    """
    # borders and spacing relative to figure size:
    xoffs = self.xborder/width
    yoffs = self.yborder/height
    xs = self.spacing/width
    ys = self.spacing/height
    if self.maxcols > 0:
        # size of a grid cell (dx, dy) and of the plot within it (xw, yw).
        # For maxcols == 0 these stay undefined, but then they are not
        # used below, because no column index is smaller than maxcols.
        dx = (1.0-xoffs)/self.maxcols
        dy = (1.0-yoffs)/self.maxcols
        xw = dx - xs
        yw = dy - ys
    # histograms:
    for c, ax in enumerate(self.hist_ax):
        if c < self.maxcols:
            ax.set_position([xoffs+c*dx, yoffs, xw, yw])
            ax.set_visible(True)
        else:
            ax.set_visible(False)
            ax.set_position([0.99, 0.01, 0.01, 0.01])
    # scatter plots:
    for ax, (c, r) in zip(self.scatter_ax[:-1], self.scatter_indices[:-1]):
        if r < self.maxcols:
            # row r is placed in the (maxcols-r)-th grid row above the histograms:
            ax.set_position([xoffs+c*dx, yoffs+(self.maxcols-r)*dy, xw, yw])
            ax.set_visible(True)
        else:
            ax.set_visible(False)
            ax.set_position([0.99, 0.01, 0.01, 0.01])
    # color bar:
    if self.maxcols > 0:
        self.cbax.set_position([xoffs+dx, yoffs+(self.maxcols-1)*dy, 0.3*xoffs, yw])
        self.cbax.set_visible(True)
    else:
        self.cbax.set_visible(False)
        self.cbax.set_position([0.99, 0.01, 0.01, 0.01])
    # magnified plot:
    if self.maxcols > 0:
        self._set_magnified_pos(width, height)
        if self.magnified_backdrop is not None:
            # let the backdrop cover the magnified plot including
            # its decorations plus a margin of self.mborder:
            bbox = self.scatter_ax[-1].get_tightbbox(self.fig.canvas.get_renderer())
            if bbox is not None:
                self.magnified_backdrop.set_bounds(bbox.x0 - self.mborder,
                                                   bbox.y0 - self.mborder,
                                                   bbox.width + 2*self.mborder,
                                                   bbox.height + 2*self.mborder)
    else:
        self.scatter_ax[-1].set_position([0.5, 0.9, 0.05, 0.05])
        self.scatter_ax[-1].set_visible(False)
    # waveform plots:
    if len(self.wave_ax) > 0:
        # lower left corner of the area available for the waveform plots:
        if self.maxcols > 0:
            x0 = xoffs+((self.maxcols+1)//2)*dx
            y0 = ((self.maxcols+1)//2)*dy
            if self.maxcols%2 == 0:
                x0 += xoffs
                y0 += yoffs - ys
            else:
                y0 += ys
        else:
            x0 = xoffs
            y0 = 0.0
        # yp is the current upper edge, dy (reused from the grid above)
        # becomes the height of a single waveform plot:
        yp = 1.0
        dy = 1.0-y0
        # reserve space below each axes with own xticks:
        dy -= np.sum(self.wave_has_xticks)*yoffs
        yp -= ys
        dy -= ys
        if self.wave_title:
            # extra space for the title:
            yp -= 2*ys
            dy -= 2*ys
        dy /= len(self.wave_ax)
        # stack the waveform plots from top to bottom:
        for ax, has_xticks in zip(self.wave_ax, self.wave_has_xticks):
            yp -= dy
            ax.set_position([x0, yp, 1.0-x0-xs, dy])
            if has_xticks:
                yp -= yoffs
            else:
                yp -= ys

1359 

1360 

def _update_layout(self):
    """Update content and position of magnified plot.

    The indices of the magnified plot are limited to the currently
    displayed `self.maxcols` columns and the magnified scatter plot
    or histogram is replotted. Then the y-limits of the histograms
    and the layout of all plots are updated and the canvas is redrawn.
    """
    magnified = self.scatter_indices[-1]   # [column, row], modified in place
    magnified_ax = self.scatter_ax[-1]
    if magnified[1] < self.data.shape[1]:
        # scatter plot: row within displayed columns, column left of the row
        if magnified[1] >= self.maxcols:
            magnified[1] = self.maxcols-1
        if magnified[0] >= magnified[1]:
            magnified[0] = magnified[1]-1
        self._plot_scatter(magnified_ax, True)
    else:
        # histogram:
        if magnified[0] >= self.maxcols:
            magnified[0] = self.maxcols-1
        self._plot_hist(magnified_ax, True)
    self._set_hist_ylim()
    extent = self.fig.get_window_extent()
    self._set_layout(extent.width, extent.height)
    self.fig.canvas.draw()

1377 

1378 

def _on_resize(self, event):
    """Adapt layout of plots to new figure size.

    Parameters
    ----------
    event: matplotlib resize event
        Provides the new `width` and `height` of the figure.
    """
    new_width, new_height = event.width, event.height
    self._set_layout(new_width, new_height)

1382 

1383 

def categorize(data):
    """Convert categorial string data into integer categories.

    Parameters
    ----------
    data: list or ndarray of str
        Data with textual categories.

    Returns
    -------
    categories: list of str
        A sorted unique list of the strings in `data`,
        that is the names of the categories.
    cdata: ndarray of int
        A copy of the input `data` where each string value is replaced
        by an integer number that is an index into the returned `categories`.
    """
    cats = sorted(set(data))
    # map each category to its index once, instead of searching
    # the list of categories for every single data element:
    index_of = {cat: k for k, cat in enumerate(cats)}
    cdata = np.array([index_of[x] for x in data], dtype=int)
    return cats, cdata

1404 

1405 

def select_features(data, columns):
    """Assemble list of column indices.

    Parameters
    ----------
    data: TableData
        The table from which to select features.
    columns: list of str
        Feature names (column headers) to be selected from the data.
        Listing a column a second time removes it again, i.e. columns
        listed an even number of times are not selected.
        Names that are not found in the table are reported on the
        console and skipped. If empty, all columns are selected.

    Returns
    -------
    data_cols: list of int
        List of indices into data columns for selecting features.
    """
    if len(columns) == 0:
        return list(np.arange(len(data)))
    selected = []
    for name in columns:
        idx = data.index(name)
        if idx is None:
            print(f'"{name}" is not a valid data column')
            continue
        if idx in selected:
            # listed before: toggle it off again
            selected.remove(idx)
        else:
            selected.append(idx)
    return selected

1435 

1436 

def select_coloring(data, data_cols, color_col):
    """Select column from data table for colorizing the data.

    Pass the output of this function on to MultivariateExplorer.set_colors().

    Parameters
    ----------
    data: TableData
        Table with all properties from which columns are selected.
    data_cols: list of int
        List of columns selected to be explored.
    color_col: str or int
        Column to be selected for coloring the data.
        If 'row' then use the row index of the data in the table for coloring.

    Returns
    -------
    colors: int or list of values or None
        Either index of `data_cols` or additional data from the data table
        to be used for coloring. For 'row', -2 is returned.
    color_label: str or None
        Label for labeling the color bar. Only set if the color column
        is not one of `data_cols`.
    color_idx: int or None
        Index of color column in `data`.
    error: None or str
        In case an invalid column is selected, an error string.
    """
    color_idx = data.index(color_col)
    if color_idx is None:
        if color_col != 'row':
            return None, None, None, f'"{color_col}" is not a valid column for color code'
        # color by row index:
        return -2, None, color_idx, None
    if color_idx in data_cols:
        # one of the explored columns: refer to it by its position
        return data_cols.index(color_idx), None, color_idx, None
    # additional column from the table, needs its own label:
    unit = data.unit(color_idx)
    if len(unit) > 0 and unit not in ['-', '1']:
        color_label = f'{data.label(color_idx)} [{unit}]'
    else:
        color_label = data.label(color_idx)
    return data[:, color_idx], color_label, color_idx, None

1480 

1481 

def list_available_features(data, data_cols=None, color_col=None):
    """Print available features on console.

    Each column of the table is listed on a line of its own. Columns
    selected for exploration are marked by '*', the column used for
    color coding by 'C'. A legend is appended for the marks in use.

    Parameters
    ----------
    data: TableData
        The full data table.
    data_cols: list of int or None
        List of indices of columns (features) in the table
        that are passed on to the MultivariateExplorer.
        None is the same as an empty list.
    color_col: int or None
        Index of columns (feature) in the table
        that is initially used for color coding the data.
    """
    if data_cols is None:
        data_cols = []
    print('available features:')
    for k, name in enumerate(data.keys()):
        color_mark = 'C' if color_col is not None and k == color_col else ' '
        select_mark = '*' if k in data_cols else ' '
        print(color_mark + select_mark + ' ' + name)
    if len(data_cols) > 0:
        print('*: feature selected for exploration')
    if color_col is not None:
        print('C: feature selected for color coding the data')

1508 

1509 

class PrintHelp(argparse.Action):
    """Argparse action printing the usage plus mouse and key bindings.

    After the standard help message of the parser, the entries of
    `MultivariateExplorer.mouse_actions` and
    `MultivariateExplorer.key_actions` are listed. Then the parser exits.
    """

    def __call__(self, parser, namespace, values, option_string):
        # action in the left column, padded to 23 characters:
        row = '%-23s %s'
        parser.print_help()
        print('')
        print('mouse:')
        for mouse_action in MultivariateExplorer.mouse_actions:
            print(row % mouse_action)
        # NOTE(review): analyze_selection() of the plain explorer does
        # nothing on a double click - confirm that this line applies here.
        print(row % ('double left click', 'run thunderfish on selected EOD waveform'))
        print('')
        print('key shortcuts:')
        for key_action in MultivariateExplorer.key_actions:
            print(row % key_action)
        parser.exit()

1523 

1524 

def demo():
    """Run the multivariate explorer with a random test data set.

    The data set has three correlated numerical features and one
    categorial feature for 100 data points. Each data point comes
    with a sine wave and a Gaussian as waveform data.
    """
    # generate data:
    n = 100
    first = np.random.randn(n) + 2.0
    second = 1.0+0.1*first + 1.5*np.random.randn(n)
    third = 10*(-3.0*first + 2.0*second + 1.8*np.random.randn(n))
    category_idx = np.random.randint(0, 3, n)
    names = ['aaa', 'bbb', 'ccc']
    data = [first, second, third, [names[i] for i in category_idx]]
    # generate waveforms, time in the first column followed by two signals:
    time = np.arange(0.0, 10.0, 0.01)
    waveforms = []
    for a, b, c in zip(first, second, third):
        sine = a*np.sin(2.0*np.pi*b*time + c)
        gauss = a*np.exp(-0.5*((time-b)/(0.2*c))**2.0)
        waveforms.append(np.column_stack((time, sine, gauss)))
    # label the features 'A', 'B', 'C', ...:
    labels = [chr(ord('A') + k) for k in range(len(data))]
    # initialize explorer:
    expl = MultivariateExplorer(data, labels, 'Explorer')
    expl.set_wave_data(waveforms, 'Time', ['Sine', 'Gauss'])
    # explore data:
    expl.set_colors()
    expl.show()

1553 

1554 

def main(*cargs):
    """Command line interface of the multivariate explorer.

    With a file given, load the table, select features and coloring
    as requested, and explore the data. Without a file run the demo.

    Parameters
    ----------
    cargs: list of str
        Command line arguments. If empty, the arguments are taken
        from `sys.argv` by argparse.
    """
    # parse command line:
    parser = argparse.ArgumentParser(
        add_help=False,
        description='View and explore multivariate data.',
        epilog=f'version {__version__} by Benda-Lab (2019-{__year__})')
    parser.add_argument('-h', '--help', nargs=0, action=PrintHelp,
                        help='show this help message and exit')
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('-l', dest='list_features', action='store_true',
                        help='list all available data columns (features) and exit')
    parser.add_argument('-d', dest='data_cols', action='append',
                        default=[], metavar='COLUMN',
                        help='data columns (features) to be explored')
    parser.add_argument('-c', dest='color_col', default=None,
                        type=str, metavar='COLUMN',
                        help='data column to be used for color code or "row"')
    parser.add_argument('-m', dest='color_map', default='jet',
                        type=str, metavar='CMAP',
                        help='name of color map to be used')
    parser.add_argument('file', nargs='?', default='', type=str,
                        help='a file containing a table of data (csv file or similar)')
    # passing None lets argparse read sys.argv:
    args = parser.parse_args(cargs if len(cargs) > 0 else None)
    if not args.file:
        demo()
        return
    # load data:
    data = TableData(args.file)
    data_cols = select_features(data, args.data_cols)
    # select column used for coloring the data:
    # NOTE(review): without the -c option color_col is None. If
    # TableData.index(None) returns None, select_coloring() reports
    # an invalid column and we exit here - confirm intended behavior.
    colors, color_label, color_col, error = \
        select_coloring(data, data_cols, args.color_col)
    if error:
        parser.error(error)
    # list features:
    if args.list_features:
        list_available_features(data, data_cols, color_col)
        parser.exit()
    # explore data:
    expl = MultivariateExplorer(data[:, data_cols])
    expl.set_colors(colors, color_label, args.color_map)
    expl.show()

1598 

1599 

# run the explorer with the command line arguments when executed as a script:
if __name__ == '__main__':
    main(*sys.argv[1:])