#-*- coding: utf-8 -*-
'''
@author: David Vilares Calvo
'''
import codecs
import re

from miopia.preprocessor.PreProcessorI import PreProcessorI
from miopia.util.exceptions.LanguageNotSupportedException import LanguageNotSupportedException

class PreProcessor(PreProcessorI):
    '''
    Tools for preprocessing a plain text: spacing of currency symbols/codes
    around numbers, separation of punctuation marks, expansion of
    abbreviations and joining of composite words.

    NOTE(review): this code targets Python 2 (it relies on ``unicode`` and
    ``str.decode``).
    '''

    RE_CURRENCY_SYMBOL = r'[\$€£]'
    RE_CURRENCY_CODE = r'[A-Z]{3}'

    # Class-level defaults (English convention); __init__ overrides them
    # according to the selected language. Both are regex fragments.
    decimal_mark = r'\.'
    digit_grouping = ','

    # lang -> (decimal_mark, digit_grouping), both as regex fragments.
    _NUMBER_FORMATS = {
        'es': (',', r'\.'),
        'en': (r'\.', ','),
    }

    # Delimiter runs that may precede / follow an abbreviation or a
    # composite word for it to be considered a standalone term.
    _OPEN_DELIMS = u'[ .,;:¡¿!?\\[]+'
    _CLOSE_DELIMS = u'[ .,;:¡¿!?\\]]+'

    # Punctuation marks that _format_punkt separates from the following text
    # (order matters: replacements are applied sequentially).
    _PUNKT = (u'.', u',', u';', u':', u'¡', u'¿')

    def _prepare_regexps(self):
        '''
        Build the number/time regex fragments from the current decimal mark
        and digit grouping.
        '''
        r_decimals = self.decimal_mark + r'[0-9]+'
        r_ordinals = r'(?:st|nd|rd|th)'
        # Optional sign, digits, grouped thousands, optional decimal part and
        # optional ordinal suffix, e.g. "-1,234.56" or "3rd" for 'en'.
        self.r_numbers = (r'-?[0-9]+(?:' + self.digit_grouping + r'[0-9]{3})*'
                          + r'(?:' + r_decimals + r')?'
                          + r_ordinals + r'?')
        self.r_numbers_std = r'-?[0-9]+(?:\.[0-9]+)?'
        self.r_time = r'[0-9]+[\.:][0-9]{2}h'

    def _convert_numbers(self, text):
        '''
        Re-join split hour expressions and put a space between numbers and
        adjacent currency symbols or currency codes.

        @param text: The raw text
        @return: The text with currency markers separated from numbers
        '''
        # TODO: removing digit grouping / normalising the decimal mark with
        # self.r_numbers is disabled because it split numbers from words.

        # Re-join an hour that appears as e.g. "9 : 30 h" into "9.30 ".
        text = re.sub(r'([0-9]+) [\.:] ([0-9]{2}) h(?:\s|$)', r'\1.\2 ', text)

        symbol, code = self.RE_CURRENCY_SYMBOL, self.RE_CURRENCY_CODE
        for pattern in ('(' + symbol + ')([0-9]+)',
                        '(' + code + ')([0-9]+)',
                        '([0-9]+)(' + symbol + ')',
                        '([0-9]+)(' + code + ')'):
            text = re.sub(pattern, r'\1 \2', text)
        return text

    def _is_number(self, text):
        '''
        @param text: A token
        @return: True if the whole token matches r_numbers_std (optional '-',
                 digits, optional '.'-separated decimal part)
        '''
        return re.match('^' + self.r_numbers_std + '$', text) is not None

    def __init__(self, composite_words=None, abbreviations=None, lang='es'):
        '''
        Constructor
        @param composite_words: A composite words dictionary in the format {OriginalWord:JoinedWord}
        @param abbreviations: An abbreviations dictionary in the format {abbreviation:OriginalWord}
        @param lang: Language code selecting the number format ('es' or 'en')
        @raise LanguageNotSupportedException: If lang is not supported
        '''
        self.lang = lang

        if lang not in self._NUMBER_FORMATS:
            raise LanguageNotSupportedException(lang)
        self.decimal_mark, self.digit_grouping = self._NUMBER_FORMATS[lang]

        self._prepare_regexps()

        # None sentinels avoid sharing one mutable default dict between
        # instances; a dict passed by the caller is kept by reference.
        self._composite_words = {} if composite_words is None else composite_words
        self._abbreviations = {} if abbreviations is None else abbreviations
        self._composite_words_patterns = self._get_composite_words_patterns(self._composite_words)
        self._abbreviations_patterns = self._get_abbreviations_patterns(self._abbreviations)
        self._special_abbreviations_patterns = self._get_special_abbreviations_patterns(self._abbreviations)

    def _build_delimited_pattern(self, escaped, word_boundary=False):
        '''
        Compile a case-insensitive regex that matches an already-escaped term
        when it is surrounded by delimiters, touches the start/end of the
        line next to a delimiter, or is the whole line.

        @param escaped: The term, already regex-escaped
        @param word_boundary: Also accept the term between regex word boundaries
        @return: The compiled pattern
        '''
        alternatives = [
            self._OPEN_DELIMS + escaped + self._CLOSE_DELIMS,
            self._CLOSE_DELIMS + escaped + u'$',
            u'^' + escaped + self._CLOSE_DELIMS,
        ]
        if word_boundary:
            alternatives.append(u'\\b' + escaped + u'\\b')
        # Without this alternative a line consisting only of the term
        # (no delimiter on either side) would never match.
        alternatives.append(u'^' + escaped + u'$')
        return re.compile(u'|'.join(alternatives), re.IGNORECASE)

    def _get_composite_words_patterns(self, dict_composite_words):
        '''
        @param dict_composite_words: {OriginalWord:JoinedWord}
        @return: {OriginalWord: compiled pattern locating that word in a line}
        '''
        # Keys are escaped so characters like '.' or '(' are matched literally.
        return {key: self._build_delimited_pattern(re.escape(key), word_boundary=True)
                for key in dict_composite_words}

    def _get_abbreviations_patterns(self, dict_abbreviations):
        '''
        @param dict_abbreviations: {abbreviation:OriginalWord}
        @return: {abbreviation: compiled pattern}; abbreviations whose pattern
                 cannot be built here are omitted and later served by the
                 special patterns
        '''
        patterns = {}
        for abbr in dict_abbreviations:
            try:
                patterns[abbr] = self._build_delimited_pattern(re.escape(abbr))
            except (re.error, UnicodeError):
                # e.g. a non-ASCII byte-string key cannot be joined with the
                # unicode delimiters; _format_abbreviations falls back to
                # self._special_abbreviations_patterns on KeyError.
                pass
        return patterns

    def _get_special_abbreviations_patterns(self, dict_abbreviations):
        '''
        Same as _get_abbreviations_patterns, but byte-string keys are
        UTF-8-decoded first, so the returned dict is keyed by unicode.

        @param dict_abbreviations: {abbreviation:OriginalWord}
        @return: {unicode abbreviation: compiled pattern}
        '''
        patterns = {}
        for abbr in dict_abbreviations:
            if not isinstance(abbr, type(u'')):
                abbr = abbr.decode('utf-8')
            patterns[abbr] = self._build_delimited_pattern(re.escape(abbr))
        return patterns

    def _format_punkt(self, token):
        """
        Normalise double quotes and separate punctuation marks in a token.

        @param token: A token
        @return: The token unchanged if it is a number; otherwise a unicode
                 token with quotes normalised and punctuation separated
        """
        if self._is_number(token):
            return token

        if isinstance(token, str):
            # Ignore invalid UTF-8 bytes instead of failing.
            token = unicode(token, "utf-8", errors="ignore")
        else:
            # Assume the value object has a proper __unicode__() method.
            token = unicode(token)

        # Normalise HTML-entity and typographic double quotes to '"'.
        token = (token.replace(u'&ldquo;', u'"').replace(u'&rdquo;', u'"')
                      .replace(u"“", u'"').replace(u"”", u'"'))

        # Ellipses are already processed correctly by the parser.
        if u"..." in token:
            return token

        for mark in self._PUNKT:
            # A trailing mark is left alone; otherwise every occurrence of
            # the mark gets a space after it.
            if not token.endswith(mark):
                token = token.replace(mark, mark + u" ")
        if u'.' in token and token != u'.':
            token = token.replace(u".", u" .")
        return token

    def _format_composite_words(self, line):
        """
        @param line: A line of a sentence
        @return: A line where composite words are joined as one token
        """
        # Candidate keys are selected once, against the original line.
        keys_in_line = [key for key in self._composite_words if key in line]
        for key in keys_in_line:
            joined = self._composite_words[key]
            for match in self._composite_words_patterns[key].findall(line):
                line = line.replace(match, match.replace(key, joined))
        return line

    def _format_upper_abbreviations(self, line, a, abbr):
        """
        Expand one matched occurrence of an abbreviation in a line.

        @param line: The line being processed
        @param a: The text matched by the abbreviation pattern (with delimiters)
        @param abbr: The abbreviation key in self._abbreviations
        @return: The line with the occurrence expanded
        """
        expansion = self._abbreviations.get(abbr)
        if abbr in a:
            return line.replace(a, a.replace(abbr, expansion))
        # Same abbreviation with different capitalisation: lower-case the
        # match and expand the lower-cased abbreviation, so e.g. "SR." is
        # expanded even when the key is "Sr.".
        lowered = a.lower()
        return line.replace(a, lowered).replace(lowered, lowered.replace(abbr.lower(), expansion))

    def _format_abbreviations(self, line):
        """
        @param line: A line of a sentence
        @return: The line with known abbreviations expanded
        """
        abbreviations_in_line = [abbr for abbr in self._abbreviations if abbr in line]

        for abbr in abbreviations_in_line:
            try:
                pattern = self._abbreviations_patterns[abbr]
            except KeyError:
                pattern = self._special_abbreviations_patterns[abbr]
            for abbreviation_found in set(pattern.findall(line)):
                line = self._format_upper_abbreviations(line, abbreviation_found, abbr)
        return line

    def preprocess(self, text):
        """
        @param text: A string
        @return: A string preprocessed
        """
        tokens = [self._format_punkt(word)
                  for word in self._convert_numbers(text).split()]
        return self._format_composite_words(self._format_abbreviations(' '.join(tokens)))


