#   This file is part of Cupydon.
#
#    Cupydon  is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    Cupydon is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with Cupydon.  If not, see <http://www.gnu.org/licenses/>
#
#    
#    Copyright(c):    ASGA, Gocad Consortium 2009-2013
#

import math

def abs(x):
    """Return the absolute value of ``x``.

    Module-local replacement for the builtin of the same name: ``x`` is
    returned unchanged when ``x >= 0`` and negated otherwise.
    """
    return x if x >= 0 else -x


def cmpres(a, b):
    """Count the disagreement between two alignments.

    ``a`` and ``b`` are walked in lockstep (the longer one is truncated by
    ``zip``).  Each element is a ``(first, last)`` pair.  Equal pairs cost
    nothing, differing pairs whose ranges overlap cost 1, and differing
    pairs with disjoint ranges cost 2.  The summed cost is returned.
    """
    total = 0
    for got, want in zip(a, b):
        if got != want:
            g_lo, g_hi = got
            w_lo, w_hi = want
            overlapping = (
                (g_lo >= w_lo and g_lo <= w_hi)
                or (g_hi >= w_lo and g_hi <= w_hi)
                or (w_lo >= g_lo and w_lo <= g_hi)
            )
            total += 1 if overlapping else 2
    return total

class DTW:
    """Base class of the dynamic-programming matchers of this module.

    ``main`` and ``sub`` are sequences whose element 0 is a flag (only its
    truthiness is compared between the two) and whose elements 1.. are the
    numeric values to align.  Building an instance runs the whole search;
    the candidates are then available in ``self.result`` as a sequence of
    ``(cost, path)`` sorted by increasing cost, where ``path`` is a tuple of
    ``(first, last)`` index pairs into the values of ``main`` (indices are
    shifted down by one relative to ``main`` itself).

    Subclasses provide ``test(x, y, colok)`` which must feed candidate steps
    for the cell ``(x, y)`` through ``add_res``.
    """

    # Maximum number of distinct paths kept per cell.
    max_path = 10

    class Options:
        # None keeps the class default.
        max_path = None

    def __init__(self, main, sub, res=None, options=None):
        """Store the inputs, apply ``options`` if given and run the search.

        ``res`` is the expected path; it is only used by the ``get_err*``,
        ``get_score*`` and ``dump_result`` helpers.
        """
        self.main = main
        self.sub = sub
        self.main_size = len(main) - 1
        self.sub_size = len(sub) - 1
        self.res = res
        if options:
            self.set_options(options)

        self.execute()

    def set_options(self, opt):
        """Copy the valid settings of ``opt`` onto this instance."""
        if opt.max_path is not None and opt.max_path >= 1:
            self.max_path = opt.max_path

    def init(self):
        """Hook called at the start of ``execute``; does nothing by default."""
        pass

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``; subclasses override."""
        raise NotImplementedError('%s must implement test()' % self.__class__.__name__)

    def execute(self):
        """Fill the table row by row and store the final cell in ``result``."""
        self.init()
        # Row 0: every column starts with one free, empty path.
        self.cur = [((0, ()), )] * (self.main_size + 1)
        samecol0 = bool(self.sub[0]) == bool(self.main[0])
        for y in range(1, self.sub_size + 1):
            self.prev = self.cur
            self.cur = []
            last_row = (y == self.sub_size)

            for x in range(self.main_size + 1):
                self.cur_x = x
                self.cur_res = []
                if last_row and x > 0:
                    # On the last row, moving right along the row is free.
                    self.add_res(False, 1, 0)

                # The flag handed to test() alternates with the parity of x+y.
                if (x + y) & 1:
                    samecol = not samecol0
                else:
                    samecol = samecol0

                self.test(x, y, samecol)
                self.cur.append(self._best_paths())

        self.result = self.cur[-1]

    def _best_paths(self):
        """Return the cheapest ``max_path`` distinct paths of ``cur_res``."""
        self.cur_res.sort()
        kept = []
        seen = set()
        for cost, path in self.cur_res:
            if path in seen:
                continue
            kept.append((cost, path))
            seen.add(path)
            if len(kept) >= self.max_path:
                break
        return tuple(kept)

    def add_res(self, prev, dx, cost, p1=None, p2=None):
        """Extend every path of a source cell into the current cell.

        The source cell is ``dx`` columns to the left, in the previous row
        when ``prev`` is true and in the row being built otherwise.  ``cost``
        is added to each source cost.  When ``p1`` is given, the pair
        ``(p1-1, p2-1)`` (``p2`` defaults to ``p1``) is appended to the path;
        otherwise the path is carried over unchanged.
        """
        if p1 is None:
            pp = ()
        elif p2 is None:
            pp = ((p1 - 1, p1 - 1), )
        else:
            pp = ((p1 - 1, p2 - 1), )

        if prev:
            col = self.prev
        else:
            col = self.cur
        x = self.cur_x - dx

        for cc, path in col[x]:
            self.cur_res.append((cc + cost, path + pp))

    def sub_val(self, x):
        """Return the value of ``sub`` at index ``x``."""
        return self.sub[x]

    def main_val(self, x):
        """Return the value of ``main`` at index ``x``."""
        return self.main[x]

    def dump_result(self):
        """Print cost, path and error of every candidate, one per line."""
        for c, r in self.result:
            # Single-argument form: same output with the print statement
            # and with the print function.
            print('%s %s %s' % (c, r, cmpres(r, self.res)))

    def _best_err(self, count):
        """Smallest error among the first ``count`` available candidates."""
        # Slicing tolerates a result shorter than ``count``.
        return min(cmpres(path, self.res) for _, path in self.result[:count])

    def get_err(self):
        """Error of the cheapest candidate against ``res``."""
        return cmpres(self.result[0][1], self.res)

    def get_err2(self):
        """Best error among the first ``max_path // 2`` candidates (at least one)."""
        return self._best_err(max(1, self.max_path // 2))

    def get_err3(self):
        """Best error among the first ``max_path`` candidates."""
        return self._best_err(self.max_path)

    def get_score(self):
        """Score out of 100 derived from ``get_err``."""
        return 100. - 50. * float(self.get_err()) / float(self.sub_size)

    def get_score2(self):
        """Score out of 100 derived from ``get_err2``."""
        return 100. - 50. * float(self.get_err2()) / float(self.sub_size)

    def get_score3(self):
        """Score out of 100 derived from ``get_err3``."""
        return 100. - 50. * float(self.get_err3()) / float(self.sub_size)

    def rstr(self):
        """Return ``'<score>(<error>)'`` for the cheapest candidate."""
        return '%.0f(%i)' % (self.get_score(), self.get_err())

class SDTW(DTW):
    """Simplest matcher: one-to-one matches plus skipping along the row."""
    name = 'SDTW'

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``; column 0 gets none."""
        if x <= 0:
            return
        if colok:
            # Diagonal step: cost is the absolute difference of both values.
            mismatch = abs(self.sub_val(y) - self.main_val(x))
            self.add_res(True, 1, mismatch, x)
        # Step along the current row: cost is the main value itself.
        self.add_res(False, 1, self.main_val(x))
    
class DTW2(DTW):
    """SDTW steps plus a match of one sub value against three main values."""
    name = 'DTW2'

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``; column 0 gets none."""
        if x <= 0:
            return
        seg = self.main_val
        if colok:
            # One-to-one diagonal step.
            self.add_res(True, 1, abs(self.sub_val(y) - seg(x)), x)
        # Step along the current row, charged with the main value.
        self.add_res(False, 1, seg(x))
        if colok and x > 2:
            # One sub value against main values x-2 .. x taken together.
            merged = abs(self.sub_val(y) - seg(x) - seg(x - 1) - seg(x - 2))
            self.add_res(True, 3, merged, x - 2, x)
    #abs(self.data1(x-1)+self.data1(x-2)+self.data1(x-3)
class DTW3(DTW):
    """Like DTW2, but the middle value of a three-value match counts twice."""
    name = 'DTW3'

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``; column 0 gets none."""
        if x <= 0:
            return
        seg = self.main_val
        if colok:
            # One-to-one diagonal step.
            self.add_res(True, 1, abs(self.sub_val(y) - seg(x)), x)
        # Step along the current row, charged with the main value.
        self.add_res(False, 1, seg(x))
        if colok and x > 2:
            # Main values x-2 .. x together; the middle one is subtracted twice.
            merged = abs(self.sub_val(y) - seg(x) - 2 * seg(x - 1) - seg(x - 2))
            self.add_res(True, 3, merged, x - 2, x)
class DTW4(DTW):
    """Matcher allowing one sub value to cover 3, 5, ... main values."""
    name = 'DTW4'
    # Number of merge widths tried: 3, 5, ... up to 1 + 2 * diag_dist.
    diag_dist = 3
    # Factor applied to the main value when stepping along the row.
    del_cost = 1.
    # Factor applied to every second main value inside a merge.
    del_cost_diag = .5

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``."""
        seg = self.main_val

        if x > 0:
            if colok:
                # One-to-one diagonal step.
                self.add_res(True, 1, abs(self.sub_val(y) - seg(x)), x)
            self.add_res(False, 1, seg(x) * self.del_cost)

        if not colok:
            return

        residual = self.sub_val(y) - seg(x)
        skipped = 0
        for d in range(self.diag_dist):
            first = x - 2 - 2 * d
            if first <= 0:
                # Wider merges would start left of the first main value too.
                break
            residual -= seg(first + 1) + seg(first)
            skipped += seg(first + 1)
            self.add_res(True, 3 + 2 * d,
                         abs(residual) + self.del_cost_diag * skipped,
                         first, x)
                    
                
            
        
#        if colok and x >2 :
#                xv = self.main_val
#                self.add_res(True,3, abs(self.sub_val(y)-xv(x)-xv(x-1)-xv(x-2))+.5*xv(x-1) , x-2, x)

class DTW4_1(DTW4):
    """DTW4 variant trying a single merge width."""
    name = 'DTW4.1'
    # With one iteration only the merge over three main values is tried.
    diag_dist = 1

class DTW4_2(DTW4):
    """DTW4 variant trying five merge widths."""
    name = 'DTW4.2'
    # Five iterations: merges over 3, 5, 7, 9 and 11 main values.
    diag_dist = 5
    #del_cost_diag = -0.43

class RDTW(DTW):
    """Matcher whose match cost compares successive main/sub length ratios.

    A diagonal step is charged according to how much the ratio between the
    matched main values and the sub value changes from the previous match
    to the current one.  Steps along the row are charged with the main
    value relative to ``del_cost``.
    """
    name = 'RDTW'
    subcost = 1.  # cost for 2. ratio
    del3cost = 1.  # factor applied to the value skipped by a 3-wide match

    # Fraction of the mean value used as a floor by sub_val / main_val.
    thres_fact = 0.

    def init(self):
        """Pre-compute the normalisation constants derived from the inputs."""
        # float() keeps every division below a true division even when the
        # inputs are integers.
        main_total = float(sum(self.main[1:]))
        sub_total = float(sum(self.sub[1:]))
        self.del_cost = main_total / 2. / self.main_size
        self.main_t = main_total * self.thres_fact / self.main_size
        self.sub_t = sub_total * self.thres_fact / self.sub_size
        self.sub_mean = sub_total / self.sub_size

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``."""
        if x > 0:
            if colok:
                self.add_res2(y, x)

            self.add_res(False, 1, self.main_val(x) / self.del_cost)
        if colok and x > 2:
            self.add_res2(y, x, 3,
                          self.del3cost * self.main_val(x - 1) / self.del_cost)

    def add_res2(self, y, x1, dx=1, add=0.):
        """Add a diagonal step matching sub value ``y`` to ``dx`` main values.

        The matched main values end at index ``x1`` (as passed in) and span
        ``dx`` entries.  ``add`` is a fixed extra cost.  Paths that already
        contain a match are additionally charged with the ratio change
        relative to their last match, weighted by the size of the sub value.
        """
        x2 = x1
        x1 = x1 - dx + 1
        vy = self.sub_val(y)
        span = float(sum(self.main_val(i) for i in range(x1, x2 + 1)))
        r0 = span / vy * self.sub_val(y - 1)
        pp = ((x1 - 1, x2 - 1), )

        for cc, path in self.prev[self.cur_x - dx]:
            if path:
                # Stored pairs are shifted down by one, hence the +1 / +2.
                ox1, ox2 = path[-1]
                rr = r0 / sum(self.main_val(i) for i in range(ox1 + 1, ox2 + 2))
                if rr < 1.:
                    rr = 1. / rr
                rr -= 1.
                rr *= math.log(1 + vy / self.sub_mean)
                self.cur_res.append((cc + rr * self.subcost + add, path + pp))
            else:
                # First match of the path: nothing to compare the ratio with.
                self.cur_res.append((cc + add, path + pp))

    def sub_val(self, x):
        """Return ``sub[x]``, raised to the floor ``sub_t``."""
        return max(self.sub_t, self.sub[x])

    def main_val(self, x):
        """Return ``main[x]``, raised to the floor ``main_t``."""
        return max(self.main_t, self.main[x])
        
class RDTW2(RDTW):
    """RDTW that also tries matching one sub value to five main values."""
    name = 'RDTW2'

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``."""
        seg = self.main_val

        if x > 0:
            if colok:
                self.add_res2(y, x)
            self.add_res(False, 1, seg(x) / self.del_cost)

        if not colok or x <= 2:
            return
        # Three main values: the middle one is charged as skipped.
        self.add_res2(y, x, 3, self.del3cost * seg(x - 1) / self.del_cost)

        if x <= 4:
            return
        # Five main values: every second one is charged as skipped.
        self.add_res2(y, x, 5,
                      self.del3cost * (seg(x - 1) + seg(x - 3)) / self.del_cost)
                
class RDTW3(RDTW):
    """RDTW that also tries matching one sub value to 5 or 7 main values."""
    name = 'RDTW3'

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``."""
        seg = self.main_val

        if x > 0:
            if colok:
                self.add_res2(y, x)
            self.add_res(False, 1, seg(x) / self.del_cost)

        if not colok or x <= 2:
            return
        # Three main values: the middle one is charged as skipped.
        self.add_res2(y, x, 3, self.del3cost * seg(x - 1) / self.del_cost)

        if x <= 4:
            return
        # Five main values: every second one is charged as skipped.
        self.add_res2(y, x, 5,
                      self.del3cost * (seg(x - 1) + seg(x - 3)) / self.del_cost)

        if x <= 6:
            return
        # Seven main values.
        self.add_res2(y, x, 7,
                      self.del3cost * (seg(x - 1) + seg(x - 3) + seg(x - 5)) / self.del_cost)
                
class R2DTW(DTW):
    """Ratio-based matcher comparing each match with several earlier ones.

    Every match stores the ratio between the matched main values and the
    sub value.  A new match is charged with the ratio change relative to up
    to ``subst_dist`` previous matches of the same path.  One sub value may
    cover 1, 3, 5, ... main values (``max_subst`` merge widths beyond 1).
    """
    name = 'R2DTW'

    subcost = .8  # cost for 2. ratio

    del3cost = 1.  # factor applied to the values skipped inside a merge

    max_subst = 5
    subst_dist = 5

    # Fraction of the mean value used as a floor by sub_val / main_val.
    thres_fact = 0.

    class Options(DTW.Options):
        max_subst = None   # 0+ (a data segment can match 1+2*n ref segments)
        subst_dist = None  # 1+ (distance to check ratio)
        del_factor = None  # float (prob to insert gap)

    def set_options(self, opt):
        """Copy the valid settings of ``opt`` onto this instance."""
        DTW.set_options(self, opt)
        if opt.max_subst is not None:
            self.max_subst = opt.max_subst
        # Validate the requested value, not the current one.
        if opt.subst_dist is not None and opt.subst_dist >= 1:
            self.subst_dist = opt.subst_dist
        if opt.del_factor is not None:
            self.subcost *= math.exp(opt.del_factor)

    #subcost = .2
    #subst_dist = 10
    #subcost = 100.

    def init(self):
        """Pre-compute the normalisation constants derived from the inputs."""
        # float() keeps every division below a true division even when the
        # inputs are integers.
        main_total = float(sum(self.main[1:]))
        sub_total = float(sum(self.sub[1:]))
        self.del_cost = main_total / 2. / self.main_size
        self.main_t = main_total * self.thres_fact / self.main_size
        self.sub_t = sub_total * self.thres_fact / self.sub_size
        self.sub_mean = sub_total / self.sub_size

    def test(self, x, y, colok):
        """Propose the steps reaching cell ``(x, y)``."""
        if x > 0:
            if colok:
                self.add_res2(y, x)

            self.add_res(False, 1, self.main_val(x) / self.del_cost)

        if colok:
            xv = self.main_val
            for i in range(self.max_subst):
                if x <= 2 + 2 * i:
                    break
                # Every second main value inside the merge is charged.
                skipped = sum(xv(x - 1 - 2 * j) for j in range(i + 1))
                self.add_res2(y, x, 3 + 2 * i,
                              self.del3cost * skipped / self.del_cost)

    def add_res2(self, y, x1, dx=1, add=0.):
        """Add a diagonal step matching sub value ``y`` to ``dx`` main values.

        The matched main values end at index ``x1`` (as passed in) and span
        ``dx`` entries.  ``add`` is a fixed extra cost.  The appended path
        element is ``(first, last, ratio)``; ``execute`` strips the ratio
        from the final result.
        """
        xv = self.main_val
        x2 = x1
        x1 = x1 - dx + 1
        cur_r = float(sum(xv(i) for i in range(x1, x2 + 1))) / self.sub_val(y)

        pp = ((x1 - 1, x2 - 1, cur_r), )

        for cc, path in self.prev[self.cur_x - dx]:
            if path:
                cost = 0
                # Compare with the most recent matches, newest first.
                for i in range(min(len(path), self.subst_dist)):
                    cost += (self.ratio_comp(cur_r, path[-(i + 1)][2])
                             * self.comp_weight(y, y - 1 - i))
                self.cur_res.append((cc + cost * self.subcost + add, path + pp))
            else:
                # First match of the path: nothing to compare the ratio with.
                self.cur_res.append((cc + add, path + pp))

    def sub_val(self, x):
        """Return ``sub[x]``, raised to the floor ``sub_t``."""
        return max(self.sub_t, self.sub[x])

    def main_val(self, x):
        """Return ``main[x]``, raised to the floor ``main_t``."""
        return max(self.main_t, self.main[x])

    def ratio_comp(self, r1, r2):
        """Return how far ``r1 / r2`` is from 1, whichever is larger."""
        v = float(r1) / r2
        if v < 1.:
            v = 1. / v
        return v - 1

    def comp_weight(self, y, y2):
        """Weight of comparing sub values ``y`` and ``y2``.

        Grows with the smaller of the two values relative to the mean.
        """
        return math.log(1 + min(self.sub_val(y), self.sub_val(y2)) / self.sub_mean)  #/(y-y2)

    def execute(self):
        """Run the search, then drop the stored ratio from every path element."""
        DTW.execute(self)
        self.result = [
            (c, [i[:2] for i in p])
            for c, p in self.result
        ]
    
    
# Matcher classes exported as the module's engine list.
# NOTE(review): R2DTW is defined in this file but not listed here - confirm
# whether that is intentional.
all_engines = (SDTW, DTW2, DTW3, DTW4, RDTW, RDTW2, RDTW3)


class Corpus:
    """Run every engine of ``dtw`` over a corpus file and print statistics.

    The first non-blank line of the corpus file is the ``main`` sequence;
    every following non-blank line is a ``(sub, res)`` pair.  Lines are
    Python expressions.
    """
    dtw = (
        SDTW,
        DTW2,
        DTW3,
        DTW4,
        RDTW,
        RDTW2,
        RDTW3,
    )

    class Stats:
        """Accumulate the three scores of one engine over the corpus."""

        def __init__(self, dtw):
            self.dtw = dtw
            self._nbr = 0
            self._tot = 0.
            self._tot2 = 0.
            self._tot3 = 0.
            self._min_score = 100.
            self._min_score2 = 100.
            self._min_score3 = 100.
            self._stats = list()

        def add(self, score, score2=0., score3=0.):
            """Record the scores of one test case."""
            if score < self._min_score:
                self._min_score = score
            if score2 < self._min_score2:
                self._min_score2 = score2
            if score3 < self._min_score3:
                self._min_score3 = score3
            self._tot += score
            self._tot2 += score2
            self._tot3 += score3
            self._stats.append((score, score2, score3, self._nbr))
            self._nbr += 1

        def dump(self):
            """Print mean and minimum scores, then the ten worst cases."""
            if not self._nbr:
                # Nothing recorded: the means are undefined.
                print('%s : no data' % self.dtw.name)
                return
            print('%s : %.1f (min:%.0f) 2:%.1f (min:%.0f) 3:%.1f (min:%.0f)' % (
                self.dtw.name,
                self._tot / self._nbr, self._min_score,
                self._tot2 / self._nbr, self._min_score2,
                self._tot3 / self._nbr, self._min_score3))
            worst = sorted(self._stats)[:10]
            print(' '.join('%i:%.0f/%.0f/%.0f' % (l, i, j, k)
                           for i, j, k, l in worst))

    def __init__(self, corpus):
        """Run the whole corpus file named ``corpus`` immediately."""
        self.run(corpus)

    def run(self, corpus):
        """Process the corpus file and print per-case and summary lines."""
        import sys

        main = None
        # A real list: it is iterated once per test case and again at the end.
        self.stats = [self.Stats(engine) for engine in self.dtw]

        n = 0
        with open(corpus) as lines:
            for line in lines:
                if not line.strip():
                    continue
                # SECURITY: eval() executes whatever the corpus file
                # contains; only use corpus files from a trusted source.
                if main is None:
                    main = eval(line)
                    continue

                sub, res = eval(line)
                # Case number prefix; do_test completes the line.
                sys.stdout.write('%s : ' % n)
                self.do_test(main, sub, res)
                n += 1

        for stat in self.stats:
            stat.dump()

    def do_test(self, main, sub, res):
        """Run every engine on one case, print and record the scores."""
        cells = []
        for stat in self.stats:
            v = stat.dtw(main, sub, res)
            cells.append(v.rstr())
            stat.add(v.get_score(), v.get_score2(), v.get_score3())
        print(' '.join(cells))
        
            
        
#Corpus('crps1')


def read_test(file_name):
    """Read one test case from ``file_name``.

    The first line is evaluated to the reference sequence, the second line
    to a ``(sequence, expected_result)`` pair.  Returns
    ``(ref, ter, res)``.
    """
    # SECURITY: eval() executes whatever the file contains; only use test
    # files from a trusted source.
    with open(file_name) as f:  # closed even if a line fails to evaluate
        ref = eval(f.readline())
        ter, res = eval(f.readline())
    return ref, ter, res



#================== In/out file ===========================
class InFile:
    """Parse an input file holding an orientation and a list of depths.

    Blank lines and lines starting with ``#`` are ignored.  The first
    remaining line gives the orientation (its first word, case-insensitive).
    Every following line is ``<number> [description]``.  The numbers are
    read as successive depths; segment lengths are the differences between
    consecutive depths, and each segment carries the description found on
    the line of its starting depth.

    After construction check ``ok()`` / ``error()``.  On success the
    instance exposes ``normal`` (bool), ``lgr`` (segment lengths), ``desc``
    (segment descriptions) and ``prof`` (depths) as tuples.  When parsing
    fails these keep their empty defaults.
    """

    _NORMAL_WORDS = ('t', 'true', '1', 'n', 'normal')
    _REVERSE_WORDS = ('f', 'false', '0', 'r', 'reverse')

    def __init__(self, filename, opt=""):
        """Read and validate ``filename``; ``opt`` is currently unused."""
        self._error = None
        self._filename = filename
        # Defaults so the attributes exist even when parsing fails.
        self.normal = True
        self.lgr = ()
        self.desc = ()
        self.prof = ()

        try:
            x = open(filename)
        except (IOError, OSError):
            self._error = "ERROR: Can't open file '%s'" % filename
            return

        try:
            parsed = self._parse(x)
        finally:
            x.close()
        if parsed is None:
            return
        normal, res = parsed

        if not res:
            self._error = "ERROR: no values"
            return

        depths = [value for value, _ in res]
        labels = [label for _, label in res]

        self._error = None

        self.normal = normal
        self.lgr = tuple(b - a for a, b in zip(depths, depths[1:]))
        # The label of the last depth does not start any segment.
        self.desc = tuple(labels[:-1])
        self.prof = tuple(depths)

        self.check()

    def _parse(self, lines):
        """Return ``(normal, [(value, description), ...])`` read from ``lines``.

        Returns None, with ``self._error`` set, on the first invalid line.
        """
        firstline = True
        normal = True
        res = []
        for line in lines:
            line = line.strip()

            if not line or line[0] == '#':
                continue

            if firstline:
                word = line.split()[0].lower()
                if word in self._NORMAL_WORDS:
                    normal = True
                elif word in self._REVERSE_WORDS:
                    normal = False
                else:
                    self._error = "ERROR: bad orientation must be 1,0,n,r,normal,reverse,true,false,t or f"
                    return None
                firstline = False
                continue

            values = line.split(None, 1)
            if len(values) == 1:
                vv = values[0]
                desc = ""
            else:
                vv, desc = values
                desc = desc.strip()
            try:
                vv = float(vv)
            except ValueError:
                self._error = "ERROR: bad value '%s'" % vv
                return None
            res.append((vv, desc))
        return normal, res

    def todtw(self):
        """Return ``(normal, length1, length2, ...)`` as used by the matchers."""
        return (self.normal, ) + self.lgr

    def ok(self):
        """Return True when no error has been recorded."""
        return not self._error

    def error(self):
        """Return the recorded error message, or None."""
        return self._error

    def check(self):
        """Validate the segments; record an error and return False if invalid."""
        for n, v in enumerate(self.lgr):
            if v < 1e-20:
                self._error = 'Segment #%i (%s) is too small' % (n, self.desc[n] or 'NO DESC')
                return False
        if len(self.lgr) < 5:
            self._error = 'Not enough segments'
            return False

        return self.ok()

    def filename(self):
        """Return the file name given at construction."""
        return self._filename
    
    
# Output structure
class OutRes(object):
    """Write the results of a match to text files.

    ``ref`` and ``data`` must provide ``desc``, ``prof`` and ``filename()``
    (as ``InFile`` does).  ``res`` is a sequence of ``(cost, path)`` where
    ``path`` is a sequence of ``(first, last)`` index pairs into ``ref``,
    one pair per segment of ``data``.
    """

    def __init__(self, ref, data, res):
        self.ref = ref
        self.data = data
        self.res = res
        self._error = None

    def _gnbr(self, n):
        """Return how many results to output for a requested count ``n``.

        ``n <= 0`` means all of them; otherwise the request is capped to
        the number of available results.
        """
        available = self.nbr_results()
        if n <= 0:
            return available
        return min(n, available)

    def _cfile(self, filename):
        """Open ``filename`` for writing; return None and record an error on failure."""
        try:
            f = open(filename, "wt")
        except (IOError, OSError):
            self._error = "ERROR: Can't create file '%s'" % filename
            return None
        self._error = None
        return f

    def ok(self):
        """Return True when the last file creation succeeded."""
        return not self._error

    def error(self):
        """Return the recorded error message, or None."""
        return self._error

    def nbr_results(self):
        """Return the number of available results."""
        return len(self.res)

    def write_all(self, filename, opt=''):
        """Placeholder; does nothing."""
        pass

    def res_name2(self, res, txt):
        """Return a one-line label for ``res`` prefixed with ``txt``."""
        cost, rr = res
        return u'%sC:%g ( %s - %s )' % (
            txt,
            cost,
            self.ref.desc[rr[0][0]], self.ref.desc[rr[-1][1]]
        )

    def res_name(self, n):
        """Return the label of result number ``n``."""
        return self.res_name2(self.res[n], u"#%i " % n)

    def all_res_name(self, nbr=0):
        """Return the labels of the first ``nbr`` results (all when ``nbr <= 0``)."""
        return [self.res_name(i) for i in range(self._gnbr(nbr))]

    def write_digest(self, filename, opt='', nbr=0):
        """Write one summary line per result to ``filename``.

        Returns False when the file cannot be created, True otherwise.
        """
        f = self._cfile(filename)
        if f is None:
            return False
        try:
            f.write('#Ref:%s\n#Data:%s\n' % (self.ref.filename(), self.data.filename()))
            f.write('\n')
            for n in range(self._gnbr(nbr)):
                c, rr = self.res[n]
                f.write('%02i\t%5f\t%s - %s\n' % (
                    n,
                    c,
                    self.ref.desc[rr[0][0]],
                    self.ref.desc[rr[-1][1]],
                ))
        finally:
            f.close()
        return True

    def write_result(self, filename, num, opt=''):
        """Write the detail of result ``num`` to ``filename``.

        After a header, one tab-separated line per data segment gives the
        reference depth, the data depth, the matched reference
        description(s) and the data description.  A final line marks the
        end of the match.  Returns False when the file cannot be created.
        """
        f = self._cfile(filename)
        if f is None:
            return False
        c, r = self.res[num]

        try:
            f.write('#Data: %s\n' % self.data.filename())
            f.write('#Ref: %s\n' % self.ref.filename())
            f.write('#result %i, cost=%g, range:%s-%s\n' % (
                num, c, self.ref.desc[r[0][0]], self.ref.desc[r[-1][1]]))
            f.write('\n')

            for ni, (start, end) in enumerate(r):
                if start == end:
                    ref_label = self.ref.desc[start]
                else:
                    ref_label = self.ref.desc[start] + ' to ' + self.ref.desc[end]
                oline = [
                    '%g' % self.ref.prof[start],
                    '%g' % self.data.prof[ni],
                    ref_label,
                    self.data.desc[ni],
                ]
                f.write('\t'.join(oline) + '\n')

            # Closing line: depths at the end of the last matched segment.
            ref = r[-1][1] + 1
            tgt = len(r)
            if len(self.ref.desc) <= ref:
                ref_label = '===END=='
            else:
                ref_label = self.ref.desc[ref]
            oline = [
                '%g' % self.ref.prof[ref],
                '%g' % self.data.prof[tgt],
                ref_label,
                '===END==',
            ]
            f.write('\t'.join(oline) + '\n')
        finally:
            f.close()
        return True

    def all_write_result(self, filename, opt='', nbr=0, fileext='.txt'):
        """Write the digest plus one detail file per result.

        Files are named ``<filename>.all<fileext>`` and
        ``<filename>.NNN<fileext>``.  Stops and returns False at the first
        file that cannot be created.
        """
        if not self.write_digest(filename + '.all' + fileext, opt=opt, nbr=nbr):
            return False

        for i in range(self._gnbr(nbr)):
            if not self.write_result("%s.%03i%s" % (filename, i, fileext), i, opt):
                return False
        return True
    
