#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
from optparse import OptionParser

def Output (string):
    """Output with colors :)"""
    print "\033[32;1m===>\033[0m %s" % string

def StartProgram():
    """Starting banner"""
    print "\033[31;1m===>\033[0m Discrete Wavelet transform started"

def EndProgram():
    """End banner"""
    print "",

def LoadingLibrariesStarted():
    """Loading libraries banner"""
    if __name__ == "__main__":
        print "\033[31;1m===>\033[0m Loading numeric libraries...",
        sys.stdout.flush ()

def LoadingLibrariesFinished():
    """Loading libraries finished banner"""
    if __name__ == "__main__":
        print "done"

if __name__ == "__main__":

    # Command-line interface: parse the options once, here, so that the
    # dispatch at the bottom of the file can use `options` and `filename`.
    parser = OptionParser (usage = "usage: %prog [options] filename")
    parser.add_option("-r", "--rebuild", dest="rebuild",
                      default=False, action="store_true",
                      help="Make DWT and then IDWT")
    parser.add_option("-w", "--write", dest="filewrite", default='rebuilt.raw',
                      help="Write reconstructed samples to this file")
    parser.add_option("-s", "--show", dest="show",
                      default=True, action="store_true",
                      help="Show the decomposed waves (this is the default)")
    # type="int" so a user-supplied depth is validated at parse time and
    # reaches DWT as an integer rather than a string.
    parser.add_option("-d", "--depth", dest="depth", type="int",
                      default=4, help="Set the recursion level of the filter bank (default is 4)")
    parser.add_option("-b", "--filterbank", dest="filterbank", default='haar',
                      help="Set the filterbank to use in the transform. Valid inputs are 'haar', 'daubechies', 'D4', 'strang'")

    (options, args) = parser.parse_args ()

    # The PCM file to transform is the mandatory positional argument;
    # parser.error() prints the usage line and exits.
    if not args:
        parser.error ("Please specify a PCM file to read the samples from")
    filename = args[0]

    if (not options.show) and (not options.rebuild):
        # Fall back to the documented default action. With default=True on
        # --show this is only a safety net and does not currently trigger.
        options.show = True


# Importing libraries
import Filtering
from pylab import show, plot, title, xlabel, ylabel, rcParams
from numpy import array, sqrt, memmap, roll
from numpy.linalg import norm
import time

# Matplotlib settings for the plots drawn by DWT.Show: typeset text with
# LaTeX in a serif font.
# NOTE(review): "text.usetex" requires a working LaTeX installation.
params = {
    "text.usetex": True,
    'font.family': 'serif',
}
rcParams.update (params)


class DWT(object):
    """Discrete wavelet transform of a raw PCM audio file.

    All the work is done by the constructor: the samples are loaded,
    decomposed with the selected filter bank and then either plotted
    (``action='show'``) or rebuilt, checked against the input and written
    back to disk (``action='rebuild'``).
    """

    def __init__(self, filename, action = 'show', filewrite = 'rebuilt.wav',
                 filterbank = 'haar', depth = 4):
        """Load *filename*, decompose it and perform *action*.

        Parameters
        ----------
        filename : str
            Path of the raw PCM file to read (see LoadSamples).
        action : str
            'show' plots the decomposition. 'rebuild' runs the inverse
            transform, reports the reconstruction error and writes the
            result to *filewrite*. Any other value only decomposes.
        filewrite : str
            Output path used by the 'rebuild' action.
        filterbank : str
            'haar', 'daubechies' / 'd4', 'strang' or 'leo'
            (case-insensitive). Unknown names fall back to Haar.
        depth : int or str
            Recursion level of the filter bank, converted with int().
        """
        StartProgram ()

        startingTime = time.time ()
        # Convert once here: the command line may pass a string, and Show()
        # does integer arithmetic on self.depth.
        self.depth = int(depth)

        filterBank = self._SelectFilterBank (filterbank)
        filterBank.SetDepth (self.depth)

        samples = self.LoadSamples (filename)
        wavelets = filterBank.Split (samples)

        Output ("Decomposed in %f seconds" % (time.time() - startingTime))
        Output ("Wavelet size: %d bytes" % (2*wavelets.GetAllSamplesNumber()))

        # Show the decomposition if the user asked for it.
        if action == 'show':
            self.Show (wavelets)

        # Compare with ==: `is` on strings tests object identity, which is
        # an interpreter implementation detail.
        if action == 'rebuild':
            startingTime = time.time ()
            rebuilt = filterBank.Rebuild (wavelets)
            Output ("Rebuilt in %f seconds" % (time.time() - startingTime))

            self._CheckReconstruction (rebuilt, samples)
            self.WriteSamples (rebuilt, filewrite)

            EndProgram ()

    def _SelectFilterBank(self, filterbank):
        """Return the Filtering filter bank named *filterbank*.

        Also sets self.filterBankName (used in the plot title). Unknown
        names produce a warning and fall back to the Haar filter bank.
        """
        name = filterbank.lower()
        if name == 'haar':
            self.filterBankName = "Haar"
            return Filtering.HaarFilterBank
        if name in ('daubechies', 'd4'):
            self.filterBankName = "Daubechies D4"
            return Filtering.DaubechiesFilterBank
        if name == 'strang':
            self.filterBankName = "Strang"
            return Filtering.StrangFilterBank
        if name == 'leo':
            self.filterBankName = "Leo"
            return Filtering.LeoFilterBank

        Output ("FilterBank %s not known. Setting 'haar'" % filterbank)
        self.filterBankName = "Haar"
        return Filtering.HaarFilterBank

    def _CheckReconstruction(self, rebuilt, samples):
        """Report how far the *rebuilt* samples are from the original ones.

        A Euclidean-norm difference above 1e-2 points to a bug. A difference
        between 1e-6 and 1e-2 is treated as floating-point approximation
        error. Anything smaller counts as a perfect reconstruction.
        """
        error = norm(rebuilt - samples)
        if error > 1E-2:
            Output ("Error while reconstructing. Rebuilt samples differs from original ones")
            Output ("||rebuilt - samples|| = %f" % error)
            Output ("There is likely an error in the code")
        elif error > 1E-6:
            Output ("Error while reconstructing. Rebuilt samples differs from original ones")
            Output ("This is likely an approximation error (the error is quite small)")
        else:
            Output ("Perfect reconstruction succeeded")

    def LoadSamples(self, filename):
        """Load the samples from a raw audio file and return them.

        The file is memory-mapped as signed 16-bit samples.
        NOTE(review): assumes headerless 16-bit PCM in native byte order,
        consistent with the 2-bytes-per-sample size estimate in __init__;
        confirm against the input files. Mode 'c' (copy-on-write) keeps the
        file on disk untouched even if the array is later written to.
        """
        samples = memmap (filename, dtype = 'int16', mode = 'c')
        Output("Loaded %d samples from %s" % (len(samples), filename))
        return samples

    def WriteSamples(self, samples, filename):
        """Write *samples* to *filename* as raw signed 16-bit PCM.

        The file is created, or overwritten, with room for exactly
        len(samples) values. Samples are rounded to the nearest integer and
        clipped to the int16 range first. A plain float-to-int cast
        truncates toward zero, so a reconstructed 99.9999 would become 99.
        """
        Output("Writing samples to %s" % filename)
        data = memmap (filename, dtype = 'int16', mode = 'w+',
                       shape = len(samples))
        data[:] = array(samples).round().clip(-32768, 32767)
        data.flush ()

    def Show(self, wavelets):
        """Plot the low-pass band and every high-pass band of *wavelets*.

        The bands are stacked with vertical offsets so the curves do not
        overlap, against a time axis in seconds. Bands are taken out of
        *wavelets* via its PopLowSamples/PopHighSamples methods.
        """
        # Sample rate used to turn sample indices into seconds on the x axis.
        # NOTE(review): hard-coded to 44.1 kHz whatever the input file is.
        frequency = float (44100)

        # Each band is downsampled by a factor of two per level, so one of
        # its samples spans `scale` original samples. Plotting every band
        # against that stride keeps them on the same time scale.
        scale = pow(2, wavelets.GetNumSamples ())

        # Vertical spacing between stacked bands. The stack starts below
        # zero so it is roughly centred. float() avoids integer overflow in
        # case the max value is a narrow numpy integer (TODO confirm type).
        singleOffset = 2 * float(wavelets.GetSamplesMaxValue())
        offset = -(self.depth // 2) * singleOffset

        # Plot only the first 60 seconds of audio, to avoid flooding memory.
        toPlot = int(frequency) * 60

        # Low-pass band first...
        scale = int(0.5 * scale)
        self._PlotBand (wavelets.PopLowSamples (), scale, offset,
                        toPlot, frequency)
        offset += singleOffset

        # ...then the high-pass bands, halving the stride at each level.
        while wavelets.GetHighSamplesNumber() > 0:
            self._PlotBand (wavelets.PopHighSamples (), scale, offset,
                            toPlot, frequency)
            offset += singleOffset
            scale = int(0.5 * scale)

        # Set some nice text
        title (r"Decomposition using %s filter bank" % self.filterBankName)
        xlabel (r"time (s)")

        show ()

    def _PlotBand(self, samples, scale, offset, toPlot, frequency):
        """Plot the first toPlot // scale values of *samples*, shifted up by
        *offset*, placing one value every *scale* original samples on a
        time axis in seconds."""
        data = samples[:toPlot // scale]
        axes = array(range(0, len(data) * scale, scale)) / frequency
        plot (axes, data + offset)

if __name__ == "__main__":

    # Decide what to run based on the options parsed at the top of the
    # script: --rebuild takes precedence over --show; with neither set,
    # nothing runs.
    commonArgs = dict(filename = filename, depth = options.depth,
                      filterbank = options.filterbank)

    if options.rebuild:
        DWT(action = 'rebuild', filewrite = options.filewrite, **commonArgs)
    elif options.show:
        DWT(action = 'show', **commonArgs)
