Coverage for src/audioio/audiotools.py: 95%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 07:29 +0000

1"""Tools for fixing audio data. 

2 

3- `despike()`: remove spikes. 

4- `unwrap()`: unwrap clipped data that are folded into the available data range. 

5""" 

6 

7import warnings 

8import numpy as np 

9 

has_numba = False
try:
    from numba import jit, prange
    has_numba = True
except ImportError:
    def jit(*args, **kwargs):
        """No-op replacement for `numba.jit` used when numba is not installed.

        Supports both the bare form (`@jit`) and the called form
        (`@jit(nopython=True)`); in either case the decorated function
        is returned unchanged and runs as plain Python.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            # used as bare `@jit`: the single argument is the function itself
            return args[0]

        def decorator_jit(func):
            return func
        return decorator_jit

    # numba's parallel range degrades to the builtin range:
    prange = range

20 

21 

def despike(data, thresh=1.0, n=1):
    """Remove spikes.

    A spike is a run of `k` (1 <= k <= `n`) consecutive data points
    that is entered by a jump of more than `thresh` in one direction
    and left by a jump of more than `thresh` in the opposite
    direction. Spikes are searched from the widest (`n`) down to the
    narrowest (1). The data points of a spike are replaced by a linear
    interpolation between the directly preceding and succeeding data
    points; for a single-point spike this is their mean.

    Parameters
    ----------
    data: 1D or 2D ndarray
        Data to be fixed in place. Columns of 2D data are processed
        as independent traces.
    thresh: float
        Threshold defining a spike.
    n: int
        Maximum width of spike.
    """
    # NOTE(review): the kernels below are created anew on every call,
    # so numba presumably recompiles them each time; consider moving
    # them to module level.

    @jit(nopython=True)
    def despike_trace(data, thresh, n):
        for k in range(n, 0, -1):
            for i in range(1, len(data) - k):
                rise = data[i] - data[i-1]      # jump into the spike
                fall = data[i+k] - data[i+k-1]  # jump out of the spike
                # upward spike (up, then down) or downward spike
                # (down, then up) - same criterion as the numpy code below:
                if (rise > thresh and fall < -thresh) or \
                   (rise < -thresh and fall > thresh):
                    # linear interpolation between data[i-1] and data[i+k]:
                    for j in range(k):
                        data[i+j] = ((k-j)*data[i-1] + (1+j)*data[i+k])/(k+1)

    @jit(nopython=True, parallel=True)
    def despike_traces(data, thresh, n):
        for c in prange(data.shape[1]):
            despike_trace(data[:, c], thresh, n)

    if data.ndim > 1:
        if has_numba and data.shape[1] > 1:
            despike_traces(data, thresh, n)
        else:
            for c in range(data.shape[1]):
                despike(data[:, c], thresh, n)
        return

    # 1D data: the vectorized numpy version is used because the numba
    # kernel was not faster here.
    for k in range(n, 0, -1):
        # find k-spikes; sel[m] marks a spike covering data[m+1:m+k+1]:
        diff = np.diff(data)
        sel = ((diff[:-k] > thresh) & (diff[k:] < -thresh)) | \
              ((diff[:-k] < -thresh) & (diff[k:] > thresh))
        # replace with weighted average of the neighbors data[m] and data[m+k+1]:
        for j in range(1, k + 1):
            data[j:-k-1+j][sel] = ((k+1-j)*data[:-1-k][sel] +
                                   j*data[1+k:][sel])/(k+1)

74 

75 

def unwrap(data, thresh=1.5, ampl_max=1.0):
    """Unwrap clipped data that are folded into the available data range.

    In some amplifiers/ADCs clipped data appear on the opposite side
    of the input range. This function tries to undo this wrapping by
    adding or subtracting `2*ampl_max` to the data following a jump
    across the input range.

    Parameters
    ----------
    data: 1D or 2D ndarray of floats
        Data to be fixed in place. Columns of 2D data are processed
        as independent traces.
    thresh: float
        Minimum difference between succeeding data points, in units of
        `ampl_max`, required for initiating unwrapping.
    ampl_max: float
        Maximum amplitude of the input range.
    """
    # NOTE(review): the kernels below are created anew on every call,
    # so numba presumably recompiles them each time; consider moving
    # them to module level.

    @jit(nopython=True)
    def unwrap_trace(data, thresh, ampl_max):
        span = 2.0*ampl_max   # width of the input range
        step = 0.0            # offset currently added to the data
        for i in range(1, len(data)):
            # data[i-1] has already been unwrapped, data[i] not yet:
            dd = data[i] - data[i-1]
            # candidate offset among 0, -span, +span that minimizes the
            # jump from the previous point:
            cstep = 0.0
            if data[i] >= 0 and abs(dd - span) < abs(dd):
                cstep = -span
            if data[i] <= 0 and abs(dd + span) < abs(dd + cstep):
                cstep = span
            # going back to no offset is always accepted, switching to a
            # non-zero offset requires a jump larger than thresh:
            if step != cstep and (cstep == 0.0 or abs(dd) > thresh):
                step = cstep
            data[i] += step

    @jit(nopython=True, parallel=True)
    def unwrap_traces(data, thresh, ampl_max):
        for c in prange(data.shape[1]):
            unwrap_trace(data[:, c], thresh, ampl_max)

    if not has_numba:
        # without numba the kernels run as plain Python loops over all samples
        warnings.warn('unwrap() runs very slowly without numba', stacklevel=2)
    thresh *= ampl_max   # thresh is given relative to ampl_max
    if data.ndim > 1:
        unwrap_traces(data, thresh, ampl_max)
    else:
        unwrap_trace(data, thresh, ampl_max)