#!/usr/bin/env python
# -*- coding: utf-8 -*-
#

import sys
from optparse import OptionParser

def Output (string):
    """Print `string` after a bold green "===>" marker (ANSI escape codes)."""
    line = "\033[32;1m===>\033[0m %s" % string
    print(line)

def StartProgram():
    """Print the start-up banner after a bold red "===>" marker (ANSI codes)."""
    banner = "\033[31;1m===>\033[0m Discrete Wavelet transform started"
    print(banner)

def EndProgram():
    """End banner.

    NOTE(review): ``print "",`` prints an empty string with the newline
    suppressed (trailing comma), so in Python 2 it produces no visible text
    of its own. It looks like a placeholder for a real closing banner --
    confirm the intent.
    """
    print "",

def LoadingLibrariesStarted():
    """Loading libraries banner"""
    if __name__ == "__main__":
        print "\033[31;1m===>\033[0m Loading numeric libraries...",
        sys.stdout.flush ()

def LoadingLibrariesFinished():
    """Complete the "Loading numeric libraries..." line (script mode only)."""
    if __name__ != "__main__":
        return
    print("done")


if __name__ == "__main__":

    # Command-line interface: parsed before the heavy numeric imports so
    # that --help and usage errors exit quickly.
    parser = OptionParser (usage = "usage: %prog [options] filename")
    parser.add_option("-r", "--rebuild", dest="rebuild",
                      default=False, action="store_true",
                      help="Make DWT and then IDWT")
    parser.add_option("-w", "--write", dest="filewrite", default='rebuilt.raw',
                      help="Write reconstructed samples to this file")
    # NOTE: with default=True and action="store_true" this flag is always
    # True; it only documents the default behaviour (--rebuild wins).
    parser.add_option("-s", "--show", dest="show",
                      default=True, action="store_true",
                      help="Show the decomposed waves (this is the default)")
    # type="int" lets optparse reject non-numeric depths up front and hands
    # an int (not a string) to the transform.
    parser.add_option("-d", "--depth", dest="depth", type="int",
                      default=4, help="Set the recursion level of the filter bank (default is 4)")
    parser.add_option("-b", "--filterbank", dest="filterbank", default='haar',
                      help="Set the filterbank to use in the transform. Valid inputs are 'haar', 'daubechies', 'D4', 'strang'")


    (options, args) = parser.parse_args ()

    # The first positional argument is the raw PCM file to read.
    try:
        filename = args[0]
    except IndexError:
        parser.error ("Please specify a PCM file to read the samples from")

    # Nothing requested: stop before loading the numeric libraries.
    if (not options.show) and (not options.rebuild):
        sys.exit(0)

# Announce the numeric imports below, which can be slow; the banner helpers
# only print when this file is executed as a script.
LoadingLibrariesStarted()

# Importing libraries
# These run after the command-line parsing above, so --help and usage
# errors exit before paying for them. Filtering provides the *FilterBank
# objects used by DWT.
# NOTE(review): sqrt, roll and ylabel appear unused in this file -- confirm
# before removing them.
import Filtering
from pylab import show, plot, title, xlabel, ylabel, rcParams
from numpy import array, sqrt, memmap, roll
from numpy.linalg import norm
import time

# Plot styling: render plot text through LaTeX with a serif font.
# NOTE(review): text.usetex requires a working LaTeX installation on the
# host, otherwise drawing text may fail -- confirm this is intended.
params = {
    "text.usetex": True,
    'font.family': 'serif',
}

rcParams.update(params)



LoadingLibrariesFinished()


class DWT(object):
    """Discrete wavelet transform of a raw little-endian 16-bit PCM file.

    All the work happens in the constructor: the samples are loaded, split
    with the selected filter bank, and then either plotted
    (``action='show'``) or reconstructed, compared with the input and
    written to ``filewrite`` (``action='rebuild'``).
    """

    # Accepted lower-case filter bank names -> (attribute of the Filtering
    # module, display name used in messages and plot titles). Attributes are
    # resolved lazily, so only the selected filter bank has to exist.
    _FILTER_BANKS = {
        'haar': ('HaarFilterBank', "Haar"),
        'daubechies': ('DaubechiesFilterBank', "Daubechies D4"),
        'd4': ('DaubechiesFilterBank', "Daubechies D4"),
        'strang': ('StrangFilterBank', "Strang"),
        'leo': ('LeoFilterBank', "Leo"),
    }

    # Representable range of the "<h" (int16) sample format.
    _PCM16_MIN = -32768
    _PCM16_MAX = 32767

    def __init__(self, filename, action = 'show', filewrite = 'rebuilt.wav',
                 filterbank = 'haar', depth = 4):
        """
        filename   -- raw PCM file, read as little-endian 16-bit samples.
        action     -- 'show' plots the decomposition; 'rebuild' runs the
                      inverse transform, checks it and writes the result.
        filewrite  -- output file for 'rebuild' (raw 16-bit PCM, no header).
        filterbank -- 'haar', 'daubechies'/'d4', 'strang' or 'leo'
                      (case-insensitive); anything else falls back to Haar.
        depth      -- recursion level of the filter bank; anything int()
                      accepts (e.g. a numeric string).
        """
        StartProgram ()

        startingTime = time.time ()
        # Stored as an int: Show() does integer arithmetic with it.
        self.depth = int(depth)

        filterBank, self.filterBankName = self._SelectFilterBank(filterbank)
        filterBank.SetDepth (self.depth)

        samples = self.LoadSamples (filename)
        wavelets = filterBank.Split (samples)

        Output ("Decomposed in %f seconds" % (time.time() - startingTime))
        Output ("Wavelet size: %d bytes" % (2*wavelets.GetAllSamplesNumber()))

        # Show the decomposition if the user asked for it.
        if action == 'show':
            self.Show (wavelets)

        if action == 'rebuild':
            self._Rebuild (filterBank, wavelets, samples, filewrite)
            EndProgram ()

    def _SelectFilterBank(self, filterbank):
        """Return (filter bank object, display name) for `filterbank`.

        Matching is case-insensitive; unknown names print a warning and
        fall back to Haar.
        """
        key = filterbank.lower()
        if key not in self._FILTER_BANKS:
            Output ("FilterBank %s not known. Setting 'haar'" % filterbank)
            key = 'haar'
        attribute, displayName = self._FILTER_BANKS[key]
        return getattr(Filtering, attribute), displayName

    def _Rebuild(self, filterBank, wavelets, samples, filewrite):
        """Invert the transform, report the reconstruction error and write
        the rebuilt samples to `filewrite`."""
        startingTime = time.time ()
        rebuilt = filterBank.Rebuild (wavelets)
        Output ("Rebuilt in %f seconds" % (time.time() - startingTime))

        # Euclidean norm of the reconstruction error: above 1e-2 points to a
        # bug in the code, above 1e-6 is attributed to floating point error.
        a = norm(rebuilt - samples)
        if (a > 1E-2):
            Output ("Error while reconstructing. Rebuilt samples differs from original ones")
            Output ("||rebuilt - samples|| = %f" % a)
            Output ("There is likely an error in the code")
        elif (a > 1E-6):
            Output ("Error while reconstructing. Rebuilt samples differs from original ones")
            Output ("This is likely an approximation error (the error is quite small)")
        else:
            Output ("Perfect reconstruction succeeded")
        self.WriteSamples(rebuilt, filewrite)

    def LoadSamples(self, filename):
        """
        Load the samples from a raw audio file, memory-mapped read-only as
        little-endian 16-bit integers.
        """
        samples = memmap (filename,
                          dtype="<h",
                          mode="r")
        Output("Loaded %d samples from %s" % (len(samples), filename))
        return samples

    @staticmethod
    def _ToPCM16(samples):
        """Return `samples` as a little-endian int16 array, rounded to the
        nearest integer and clipped to the int16 range.

        A plain int16 assignment would truncate (a floating point
        reconstruction such as 2.9999999 would become 2) and out-of-range
        values would overflow.
        """
        from numpy import asarray, clip, rint
        values = rint(asarray(samples, dtype=float))
        return clip(values, DWT._PCM16_MIN, DWT._PCM16_MAX).astype("<h")

    def WriteSamples(self, samples, filename):
        """
        Write `samples` to `filename` as raw little-endian 16-bit PCM,
        rounding and clipping them first.
        """
        Output("Writing samples to %s" % filename)
        pcm = self._ToPCM16(samples)
        if len(pcm) == 0:
            # A zero-length file cannot be memory-mapped; just create it.
            open(filename, "wb").close()
            return
        data = memmap (filename,
                       dtype="<h",
                       mode="w+",
                       shape = len(pcm))
        data[:] = pcm
        data.flush ()

    def Show(self, wavelets):
        """
        Plot the low-pass band and then every remaining high-pass band of
        `wavelets`, stacked vertically on a shared time axis in seconds.

        This consumes `wavelets`: the bands are popped while plotting.
        """

        # Sample rate used to turn sample indices into seconds on the x-axis.
        # NOTE(review): hard-coded to 44.1 kHz; raw PCM carries no rate.
        frequency = float (44100)

        # Bands are recursively downsampled by a factor of two, so each
        # band's samples are spread `scale` input samples apart to keep all
        # plots on the same time scale.
        scale = pow(2, wavelets.GetNumSamples ())

        # Vertical spacing between bands; start below zero so the stack is
        # roughly centred.
        singleOffset = 2 * wavelets.GetSamplesMaxValue()
        offset = -(self.depth // 2) * singleOffset

        # Only the first 60 seconds of audio are plotted, to keep memory
        # usage bounded.
        toPlot = int(frequency) * 60

        # Low-pass band first.
        scale = int(0.5 * scale)
        low = wavelets.PopLowSamples()
        data = low[:toPlot // scale]

        axes = array(range(0, len(data) * scale, scale)) / frequency

        plot(axes, data + offset)

        offset += singleOffset

        # Then each high-pass band, halving the spacing every time.
        while (wavelets.GetHighSamplesNumber() > 0):

            samples = wavelets.PopHighSamples ()

            data = samples[0:toPlot // scale]
            axes = array(range(0, len(data) * scale , scale)) / frequency

            plot (axes, data + offset)
            offset += singleOffset
            scale = int(0.5*scale)


        # Set some nice text
        title (r"Decomposition using %s filter bank" % self.filterBankName)
        xlabel (r"time (s)")

        show ()




if __name__ == "__main__":

    # Decide what to run from the options parsed above: --rebuild takes
    # precedence over --show; only the rebuild path needs an output file.
    if options.rebuild:
        chosen = dict(action = 'rebuild', filewrite = options.filewrite)
    elif options.show:
        chosen = dict(action = 'show')
    else:
        chosen = None

    if chosen is not None:
        DWT(filename = filename, depth = options.depth,
            filterbank = options.filterbank, **chosen)




ViewGit