Audiofiltering:

Increase efficiency of convolution algorithm.

Approximately 150x faster. (takes ~4.5s to lowpass filter ~1min of audio)
This commit is contained in:
David Sutton 2022-12-28 14:32:24 +10:30
parent b2d2a41fdc
commit 52a56f3a52
2 changed files with 120 additions and 21 deletions

View File

@ -16,21 +16,60 @@ func main() {
start := time.Now() start := time.Now()
// Read the audio data from the file. // Read the audio data from the file.
input, _ := os.ReadFile("sine.pcm") input, _ := os.ReadFile("whitenoise.pcm")
// Create Buffer in struct format defined in pcm.go // Create Buffer in struct format defined in pcm.go
format := pcm.BufferFormat{Rate: 44100, Channels: 1, SFormat: pcm.S16_LE} format := pcm.BufferFormat{Rate: 44100, Channels: 1, SFormat: pcm.S16_LE}
buf := pcm.Buffer{Format: format, Data: input} buf := pcm.Buffer{Format: format, Data: input}
// Create a filter. // Create a filter.
amp := pcm.Filter{BuffInfo: buf.Format, Type: pcm.AMPLIFIER, Upper: 3, Taps: 50} bs := pcm.Filter{BuffInfo: buf.Format, Type: pcm.BANDSTOP, Lower: 1000, Upper: 2000, Taps: 500}
// Apply the lowpass filter to the buffer. // Apply different filters to save and compare.
output := amp.Apply(buf) bs.Generate()
bs_1khz := bs.Apply(buf)
fmt.Println("Applied 1Khz Bandstop filter")
bs.Lower, bs.Upper = 2000, 5000
bs.Generate()
bs_2khz := bs.Apply(buf)
fmt.Println("Applied 2Khz Bandstop filter")
bs.Lower, bs.Upper = 5000, 10000
bs.Generate()
bs_5khz := bs.Apply(buf)
fmt.Println("Applied 5Khz Bandstop filter")
bs.Lower, bs.Upper = 10000, 15000
bs.Generate()
bs_10khz := bs.Apply(buf)
fmt.Println("Applied 10Khz Bandstop filter")
bs.Lower, bs.Upper = 15000, 18000
bs.Generate()
bs_15khz := bs.Apply(buf)
fmt.Println("Applied 15Khz Bandstop filter")
// Save the transformed audio. // Save the transformed audio.
f, _ := os.Create("output.pcm") f, _ := os.Create("bs_1khz.pcm")
f.Write(output) f.Write(bs_1khz)
fmt.Println("Wrote audio to bs_1khz.pcm")
f, _ = os.Create("bs_2khz.pcm")
f.Write(bs_2khz)
fmt.Println("Wrote audio to bs_2khz.pcm")
f, _ = os.Create("bs_5khz.pcm")
f.Write(bs_5khz)
fmt.Println("Wrote audio to bs_5khz.pcm")
f, _ = os.Create("bs_10khz.pcm")
f.Write(bs_10khz)
fmt.Println("Wrote audio to bs_10khz.pcm")
f, _ = os.Create("bs_15khz.pcm")
f.Write(bs_15khz)
fmt.Println("Wrote audio to bs_15khz.pcm")
// Display execution time. // Display execution time.
fmt.Println("Finished execution. Total time:", time.Since(start)) fmt.Println("Finished execution. Total time:", time.Since(start))

View File

@ -2,9 +2,11 @@ package pcm
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"math" "math"
"sync" "sync"
"github.com/mjibson/go-dsp/fft"
"github.com/mjibson/go-dsp/window" "github.com/mjibson/go-dsp/window"
) )
@ -54,7 +56,7 @@ func (filter *Filter) Generate() {
// Create a lowpass filter with characteristics from struct. // Create a lowpass filter with characteristics from struct.
size := filter.Taps + 1 size := filter.Taps + 1
filter.Coeffs = make([]float64, size, size) filter.Coeffs = make([]float64, size, size)
fd := (filter.Upper * 0.7) / float64(filter.SampleRate) fd := (filter.Upper) / float64(filter.SampleRate)
b := (2 * math.Pi) * fd b := (2 * math.Pi) * fd
winData := window.FlatTop(size) winData := window.FlatTop(size)
for n := 0; n < (filter.Taps / 2); n++ { for n := 0; n < (filter.Taps / 2); n++ {
@ -70,7 +72,7 @@ func (filter *Filter) Generate() {
// Create a HighIGHPASSass filter with characteristics from struct. // Create a HighIGHPASSass filter with characteristics from struct.
size := filter.Taps + 1 size := filter.Taps + 1
filter.Coeffs = make([]float64, size, size) filter.Coeffs = make([]float64, size, size)
fd := (filter.Lower + 3000) / float64(filter.SampleRate) fd := (filter.Lower) / float64(filter.SampleRate)
b := (2 * math.Pi) * fd b := (2 * math.Pi) * fd
winData := window.FlatTop(size) winData := window.FlatTop(size)
for n := 0; n < (filter.Taps / 2); n++ { for n := 0; n < (filter.Taps / 2); n++ {
@ -85,31 +87,31 @@ func (filter *Filter) Generate() {
// Make Low and HighIGHPASSass filters. // Make Low and HighIGHPASSass filters.
lowpass := Filter{Type: LOWPASS, SampleRate: filter.SampleRate, Upper: filter.Upper, Taps: filter.Taps} lowpass := Filter{Type: LOWPASS, SampleRate: filter.SampleRate, Upper: filter.Upper, Taps: filter.Taps}
hIGHPASS := Filter{Type: HIGHPASS, SampleRate: filter.SampleRate, Lower: filter.Lower, Taps: filter.Taps} highpass := Filter{Type: HIGHPASS, SampleRate: filter.SampleRate, Lower: filter.Lower, Taps: filter.Taps}
lowpass.Generate() lowpass.Generate()
hIGHPASS.Generate() highpass.Generate()
// Convolve lowpass filter with highIGHPASSass filter to get bandpass filter. // Convolve lowpass filter with highIGHPASSass filter to get bandpass filter.
var wg sync.WaitGroup var wg sync.WaitGroup
ch := make(chan []float64) ch := make(chan []float64, 1)
wg.Add(1) wg.Add(1)
go Convolve(lowpass.Coeffs, hIGHPASS.Coeffs, &wg, ch) go FastConvolve(lowpass.Coeffs, highpass.Coeffs, &wg, ch)
wg.Wait() wg.Wait()
filter.Coeffs = <-ch filter.Coeffs = <-ch
case BANDSTOP: case BANDSTOP:
// Make Low and HighIGHPASSass filters. // Make Low and Highpass filters.
lowpass := Filter{Type: LOWPASS, SampleRate: filter.SampleRate, Upper: filter.Lower, Taps: filter.Taps} lowpass := Filter{Type: LOWPASS, SampleRate: filter.SampleRate, Upper: filter.Lower, Taps: filter.Taps}
hIGHPASS := Filter{Type: HIGHPASS, SampleRate: filter.SampleRate, Lower: filter.Upper, Taps: filter.Taps} highpass := Filter{Type: HIGHPASS, SampleRate: filter.SampleRate, Lower: filter.Upper, Taps: filter.Taps}
lowpass.Generate() lowpass.Generate()
hIGHPASS.Generate() highpass.Generate()
// Add lowpass filter to highIGHPASSass filter to get bandstop filter. // Add lowpass filter to highpass filter to get bandstop filter.
size := filter.Taps + 1 size := filter.Taps + 1
filter.Coeffs = make([]float64, size, size) filter.Coeffs = make([]float64, size, size)
for i := range lowpass.Coeffs { for i := range lowpass.Coeffs {
filter.Coeffs[i] = lowpass.Coeffs[i] + hIGHPASS.Coeffs[i] filter.Coeffs[i] = lowpass.Coeffs[i] + highpass.Coeffs[i]
} }
} }
@ -149,7 +151,7 @@ func (filter *Filter) Apply(b Buffer) []byte {
var wg sync.WaitGroup var wg sync.WaitGroup
ch := make(chan []float64, 1) ch := make(chan []float64, 1)
wg.Add(1) wg.Add(1)
go Convolve(inputAsFloat, filter.Coeffs, &wg, ch) go FastConvolve(inputAsFloat, filter.Coeffs, &wg, ch)
wg.Wait() wg.Wait()
convolution = <-ch convolution = <-ch
} }
@ -158,9 +160,14 @@ func (filter *Filter) Apply(b Buffer) []byte {
var output []byte var output []byte
buf := make([]byte, 2) buf := make([]byte, 2)
for i := range convolution { for i := range convolution {
if convolution[i] >= 1 {
convolution[i] = 0.9999
} else if convolution[i] <= -1 {
convolution[i] = -0.9999
}
convolution[i] = convolution[i] * 32767 convolution[i] = convolution[i] * 32767
if convolution[i] < 0 { if convolution[i] < 0 {
convolution[i] = convolution[i] + 32768*2 convolution[i] = convolution[i] + 32767*2
} }
binary.LittleEndian.PutUint16(buf[:], uint16(convolution[i])) binary.LittleEndian.PutUint16(buf[:], uint16(convolution[i]))
output = append(output, buf[0], buf[1]) output = append(output, buf[0], buf[1])
@ -170,6 +177,7 @@ func (filter *Filter) Apply(b Buffer) []byte {
} }
// Convolve takes in a signal and an FIR filter and computes the convolution. (runs in O(n^2) time)
func Convolve(x, h []float64, wg *sync.WaitGroup, ch chan []float64) { func Convolve(x, h []float64, wg *sync.WaitGroup, ch chan []float64) {
// Create a waitgroup to be used in goroutines called by Convolution // Create a waitgroup to be used in goroutines called by Convolution
@ -178,9 +186,10 @@ func Convolve(x, h []float64, wg *sync.WaitGroup, ch chan []float64) {
// Compute the convolution // Compute the convolution
convLen := len(x) + len(h) - 1 convLen := len(x) + len(h) - 1
y := make([]float64, convLen) y := make([]float64, convLen)
var progress int
for n := 0; n < convLen; n++ { for n := 0; n < convLen; n++ {
convwg.Add(1) convwg.Add(1)
go func(n int, y []float64, convwg *sync.WaitGroup) { go func(n int, y []float64, convwg *sync.WaitGroup, progress *int) {
var sum float64 = 0 var sum float64 = 0
for k := 0; k < len(x); k++ { for k := 0; k < len(x); k++ {
if n-k >= 0 && n-k < len(h) { if n-k >= 0 && n-k < len(h) {
@ -190,14 +199,65 @@ func Convolve(x, h []float64, wg *sync.WaitGroup, ch chan []float64) {
} }
} }
y[n] = sum y[n] = sum
*progress++
convwg.Done() convwg.Done()
}(n, y, &convwg) }(n, y, &convwg, &progress)
fmt.Println(float64(progress) * 100 / float64(convLen))
} }
convwg.Wait()
ch <- y ch <- y
close(ch) close(ch)
wg.Done() wg.Done()
} }
// FastConvolve takes in a signal and an FIR filter and computes the convolution. (runs in O(nlog(n)) time)
func FastConvolve(x, h []float64, wg *sync.WaitGroup, ch chan []float64) {
// Calculate the length of the linear convolution
convLen := len(x) + len(h) - 1
// Pad signals to the next largest power of 2 larger than convLen
padLen := int(math.Pow(2, math.Ceil(math.Log2(float64(convLen)))))
zeros := make([]float64, padLen-len(x), padLen-len(h))
x = append(x, zeros...)
zeros = make([]float64, padLen-len(h))
h = append(h, zeros...)
// Compute DFFTs
X, H := fft.FFTReal(x), fft.FFTReal(h)
// Compute the multiplication of the two signals in the freq domain
var convWG sync.WaitGroup
Y := make([]complex128, padLen)
for i := range X {
convWG.Add(1)
go func(a, b complex128, y *complex128, convWG *sync.WaitGroup, i int) {
*y = a * b
convWG.Done()
}(X[i], H[i], &Y[i], &convWG, i)
}
convWG.Wait()
// Compute the IDFFT
iy := fft.IFFT(Y)
// Convert to []float64
y := make([]float64, padLen)
for i := range iy {
convWG.Add(1)
go func(a complex128, y *float64, convWG *sync.WaitGroup) {
*y = real(a)
convWG.Done()
}(iy[i], &y[i], &convWG)
}
convWG.Wait()
// Trim to length of linear convolution
ch <- y[0:convLen]
wg.Done()
}
// LEGACY FUNCTION FOR TESTING // LEGACY FUNCTION FOR TESTING
// Max returns the absolute highest value in a given array. // Max returns the absolute highest value in a given array.
func Max(a []float64) float64 { func Max(a []float64) float64 {