Aggiunto piccolo modulo in fortran per calcolare più velocemente

Leonardo Robol [2010-03-22 17:24]
Aggiunto piccolo modulo in fortran per calcolare più velocemente
la decomposizione e la ricomposizione.
Filename
Filtering/Filtering.py
Filtering/Makefile
Filtering/fast_filters.f90
RefinementEquation/Iteration.py
diff --git a/Filtering/Filtering.py b/Filtering/Filtering.py
index ff85515..f4de43f 100644
--- a/Filtering/Filtering.py
+++ b/Filtering/Filtering.py
@@ -3,6 +3,8 @@
 #

 from numpy import array, zeros, convolve, dot, roll, sqrt, concatenate
+from numpy.linalg import norm
+from fast_filters import filter_and_downsample, upsample_and_filter

 class AbstractFilter():

@@ -21,7 +23,7 @@ class FIR(AbstractFilter):

     def __init__(self, coefficients):

-        self.coefficients = coefficients
+        self.coefficients = array(coefficients)

     def __call__(self, samples):
         """Apply the FIR to samples"""
@@ -53,6 +55,33 @@ def UpSample(samples):
     s[::2] = samples
     return array(s)

+
+def PRCheck(f):
+    """
+    Return the delay that the filterbank will cause or raise
+    RuntimeError if PR condition is not satisfied"""
+    a = 0.5 * (f.lowPassFilter * f.lowPassInverseFilter).GetResponse ()
+    b = 0.5 * (f.highPassFilter * f.highPassInverseFilter).GetResponse ()
+    f0 = f.lowPassInverseFilter.GetResponse()
+    f1 = f.highPassInverseFilter.GetResponse()
+    f0[1::2] = (-1) * f0[1::2]
+    f1[1::2] = (-1) * f1[1::2]
+    c = (f.lowPassFilter * FIR(f0)).GetResponse()
+    d = (f.highPassFilter * FIR(f1)).GetResponse()
+
+    # Ora cerco di capire quanto sarà il ritardo.
+    for i in range(0,len(a+b)):
+        if (a+b)[i] > 0.5:
+            break
+    s = zeros(len(a+b))
+    s[i] = 1
+    if (norm(a+b - s) > 1e-15):
+        raise RuntimeError('PR condition is not satisfied')
+    if (norm(c+d) > 1e-15):
+        raise RuntimeError('PR condition on aliasing is not satisfied')
+    return i
+
+
 class WaveletStack():

     def __init__(self):
@@ -201,6 +230,7 @@ class FilterBank():
         """
         self.depth = depth

+
     def SetLowPassFilter(self, lowpass):
         """
         Set the LowPassFilter for the filterbank.
@@ -209,9 +239,6 @@ class FilterBank():
         """
         self.lowPassFilter = lowpass

-        # Sembra che la lunghezza del filtro debba essere
-        # la metà del filtro lowpass, arrotondata per eccesso
-        self.SetLength ( int((len(lowpass) + 1)/2) )

     def SetHighPassFilter(self, highpass):
         """
@@ -235,12 +262,17 @@ class FilterBank():
         """
         self.highPassInverseFilter = highpass

-    def SetLength(self, length):
+    def SetLength(self, length = None):
         """
         Set the length of the filter, i.e how much
         will be the delay when recovering
         """
-        self.length = length
+        if length is None:
+            self.length = PRCheck(self)
+        else:
+            self.length = length
+        # print "length set to %d" % self.length
+

     def EndEffectLength(self):
         """Cerchiamo di dare una valutazione di quanto gli effetti
@@ -268,8 +300,11 @@ class FilterBank():
         # Do the real filtering and downsampling. Store the downsampled
         # details in the array.
         for recursion in xrange(0, self.depth):
-            samplesStack.PushHighSamples (DownSample (self.highPassFilter (low)))
-            low = DownSample (self.lowPassFilter (low))
+            # samplesStack.PushHighSamples (DownSample (self.highPassFilter (low)))
+            # low = DownSample (self.lowPassFilter (low))
+            samplesStack.PushHighSamples (
+                filter_and_downsample(low, 2, self.highPassFilter.GetResponse()))
+            low = filter_and_downsample(low, 2, self.lowPassFilter.GetResponse())

         # In the end append the low downsampled data to the array
         samplesStack.PushLowSamples (low)
@@ -294,9 +329,11 @@ class FilterBank():
             high = samplesStack.PopHighSamples ()

             # E li filtriamo insieme ai low samples.
-            low = self.lowPassInverseFilter (UpSample (low))
+            # low = self.lowPassInverseFilter (UpSample (low))
+            low = upsample_and_filter(low, 2, self.lowPassInverseFilter.GetResponse())

-            low += self.highPassInverseFilter (UpSample (high))
+            # low += self.highPassInverseFilter (UpSample (high))
+            low += upsample_and_filter(high, 2, self.highPassInverseFilter.GetResponse())

             # Facciamo shiftare l'array in modo che il delay dei sample
             # ricostruiti non disturbi la ricostruzione dei prossimi.
@@ -329,14 +366,6 @@ HaarFilterBank.SetFilterMatrix ( [
         ])
 HaarFilterBank.SetLength(1)

-LeoFilterBank = FilterBank()
-LeoFilterBank.SetFilterMatrix ( [
-        [0.25 , 0.5 , 0.25 ],
-        [0.25 , -0.5, 0.25] ,
-        [0.5 , 1, 0.5],
-        [-0.5, 1, -0.5]
-        ])
-LeoFilterBank.SetLength(2)

 StrangFilterBank = FilterBank ()
 StrangFilterBank.SetFilterMatrix ( [
@@ -346,12 +375,7 @@ StrangFilterBank.SetFilterMatrix ( [
         0.25 * array([1,2,-6,2,1])
         ])
 StrangFilterBank.SetLength(3)
-
-def PRCheck(f):
-
-    a = 0.5 * (f.lowPassFilter * f.lowPassInverseFilter).GetResponse ()
-    b = 0.5 * (f.highPassFilter * f.highPassInverseFilter).GetResponse ()
-    return a + b
+


 h0 = StrangFilterBank.lowPassFilter
diff --git a/Filtering/Makefile b/Filtering/Makefile
new file mode 100644
index 0000000..fa1a0cc
--- /dev/null
+++ b/Filtering/Makefile
@@ -0,0 +1,11 @@
+F2PY=f2py
+MODULE_NAME=fast_filters
+SOURCE_FILES=fast_filters.f90
+
+all: fast_filters.so
+
+fast_filters.so:
+	$(F2PY) -c -m $(MODULE_NAME) $(SOURCE_FILES)
+
+clean:
+	rm -f fast_filters.so
diff --git a/Filtering/fast_filters.f90 b/Filtering/fast_filters.f90
new file mode 100644
index 0000000..08ad24e
--- /dev/null
+++ b/Filtering/fast_filters.f90
@@ -0,0 +1,72 @@
+SUBROUTINE filter_and_downsample(output, samples, downsample, filter, n, m)
+
+  IMPLICIT NONE
+  ! Variabili
+  INTEGER m, n
+  DOUBLE PRECISION, DIMENSION(n) :: samples
+  DOUBLE PRECISION, DIMENSION(m) :: filter
+
+  INTEGER downsample
+  INTEGER k, i, s
+  DOUBLE PRECISION, DIMENSION(n/downsample) :: output
+
+!F2PY INTENT(IN) :: samples
+!F2PY INTENT(IN) :: filter
+!F2PY INTENT(IN) :: downsample
+!F2PY INTENT(HIDE) :: n
+!F2PY INTENT(HIDE) :: m
+!F2PY INTENT(OUT) :: output
+
+  ! Cominciamo a filtrare
+  s = 1
+
+  ! Applichiamo il filtro al sample k-esimo
+  DO k = 1,n,downsample
+     output(s) = 0
+     i = 0
+     DO WHILE( i < m .and. i < k )
+        output(s) = output(s) + filter(i + 1)*samples(k - i)
+        i = i + 1
+     END DO
+
+     ! Passiamo ai prossimi sample
+     s = s + 1
+
+  END DO
+
+END SUBROUTINE
+
+SUBROUTINE upsample_and_filter(output, samples, upsample, filter, n, m)
+
+  ! dichiarazioni
+  INTEGER :: m,n, upsample, i,j,s
+  DOUBLE PRECISION, DIMENSION(n) :: samples
+  DOUBLE PRECISION, DIMENSION(m) :: filter
+  DOUBLE PRECISION, DIMENSION(upsample * n) :: output
+
+  !F2PY INTENT(IN) samples
+  !F2PY INTENT(IN) filter
+  !F2PY INTENT(IN) upsample
+  !F2PY INTENT(HIDE) m
+  !F2PY INTENT(HIDE) n
+  !F2PY INTENT(OUT) output
+
+  s = 1
+  DO i = 1, n
+     ! Calcolo l'elemento in posizione s e s+1
+     output(s) = 0
+     output(s+1) = 0
+     j = 0
+
+     ! In questo ciclo calcoliamo la convoluzione del filtro
+     ! con il vettore upsampled sfruttando l'informazione che
+     ! nei posti dispari c'è solo 0.
+     DO WHILE(j < m .and. j < s)
+        output(s)     = output(s) + samples(i - j/2) * filter(j + 1)
+        output(s + 1) = output(s + 1) + samples(i - j/2) * filter(j + 2)
+        j = j + 2
+     END DO
+     s = s + 2
+  END DO
+
+END SUBROUTINE
diff --git a/RefinementEquation/Iteration.py b/RefinementEquation/Iteration.py
index 5053e10..ba02797 100755
--- a/RefinementEquation/Iteration.py
+++ b/RefinementEquation/Iteration.py
@@ -14,8 +14,9 @@ print "ok"
 # Si comincia ad iterare
 h = numpy.array([0.125 , 0.25, 0.25, 0.25, 0.125])
 h = Filtering.DaubechiesFilterBank.lowPassFilter.GetResponse()
+h = Filtering.StrangFilterBank.lowPassFilter.GetResponse()
 # h = Filtering.StrangFilterBank.lowPassFilter.GetResponse()
-t = numpy.linspace(-0.5,len(h) + 0.05,1000)
+t = numpy.linspace(-0.5,len(h) + 0.05,100)

 def box(x):
     """box function"""
@@ -49,8 +50,8 @@ hold(False)
 diff = 1
 try:
     while True:
-        # plot(t,phi)
-        # draw ()
+        plot(t,phi)
+        draw ()
         print "diff = %f" % diff
         time.sleep (1)
         newphi = map(refinement, t)
ViewGit