import numpy
import scipy
import pylab
from math import exp,sqrt,pi
from exceptions import Exception
scipy.pkgload('interpolate')
scipy.pkgload('integrate')
scipy.pkgload('optimize')

class err(Exception):
    """Common base class for the exceptions defined in this module."""


class InvertError(err):
    """General-purpose inversion error."""


class LeftError(err):
    """Raised by inverse() when the search to the left of the peak fails."""


class RightError(err):
    """Raised by inverse() when the search to the right of the peak fails."""

def load(filename,column_1,column_2):
    """
    Load two columns from the ASCII table in filename.

    column_1 and column_2 are 1-based column numbers.  The file is parsed with
    numpy.loadtxt (whitespace-delimited, lines/trailing text after '#' ignored).

    Returns the two selected columns (x, y) as 1-D float arrays.
    Raises ValueError if either column number is below 1.
    """
    if column_1 < 1 or column_2 < 1:
        # Without this check column 0 would map to index -1 and silently
        # select the last column.
        raise ValueError("column numbers are 1-based and must be >= 1")
    # numpy.loadtxt replaces pylab.load, which matplotlib no longer provides.
    # ndmin=2 keeps a single-row file two-dimensional so column slicing works.
    data = numpy.loadtxt(filename, ndmin=2)
    return data[:, column_1 - 1], data[:, column_2 - 1]

def interp(x,y):
    """Build and return a piecewise-linear interpolating function through the points (x, y)."""
    linear = scipy.interpolate.interp1d(x, y, kind='linear')
    return linear

def floor_index(array,value):
    """
    Find the largest index with array[index] <= value.

    value must lie within [array[0], array[-1]] (inclusive).  For an array
    sorted ascending this is the index of the tabulated point at, or
    immediately below, value.

    Returns a plain int.
    Raises ValueError if array is empty or value is outside the range
    (a NaN value is treated as outside the range).
    """
    arr = numpy.asarray(array)
    # The chained comparison is False for NaN, so NaN is rejected here too.
    if arr.size == 0 or not (arr[0] <= value <= arr[-1]):
        raise ValueError("Value not in array range")
    # flatnonzero returns indices in ascending order, so the last one is the
    # largest.  arr[0] <= value guarantees at least one match.
    return int(numpy.flatnonzero(arr <= value)[-1])

def tab_area(x,y,x0,x1,interpolator=None):
    """
    Find the area under the curve y(x) between x0 and x1 when
    x and y are tabulated.

    The integral is the trapezoid rule over the tabulated points strictly
    inside (x0, x1).  The end points x0 and x1 are added using the
    interpolator.  Because it follows the kinks in y, this is quicker and
    more accurate than integrating an interpolation function with 'area'.

    If you already have an interpolator defined you can supply it as the
    interpolator argument; any callable mapping a scalar to a scalar (or a
    1-element array) works.  Otherwise a linear one is built with interp().

    If you have a proper function from which y are just samples, consider
    using 'area' for better accuracy.

    x should be monotonic increasing, and x0 and x1 must lie within
    [x[0], x[-1]].  If x1 < x0 the negated area from x1 to x0 is returned.
    """
    if x1 < x0:
        # Integrating right-to-left is the negative of the left-to-right area.
        return -tab_area(x, y, x1, x0, interpolator=interpolator)
    x = numpy.asarray(x, dtype=float)
    y = numpy.asarray(y, dtype=float)
    if interpolator is None:
        interpolator = interp(x, y)
    # Reduce to one scalar index even if floor_index hands back an index array.
    low_index = int(numpy.max(floor_index(x, x0)))
    high_index = int(numpy.max(floor_index(x, x1)))
    inner = slice(low_index + 1, high_index + 1)
    new_x = numpy.concatenate(([x0], x[inner], [x1]))
    # interp1d returns 0-d arrays for scalar input, which concatenate rejects.
    new_y = numpy.concatenate((numpy.atleast_1d(interpolator(x0)),
                               y[inner],
                               numpy.atleast_1d(interpolator(x1))))
    # Trapezoid rule written out directly.  scipy.integrate.trapz has been
    # removed from SciPy.
    return float(numpy.sum(numpy.diff(new_x) * (new_y[1:] + new_y[:-1])) / 2.0)


def area(func,low,high):
    """
    Return the definite integral of func from x=low to x=high.

    Uses scipy.integrate.quad (adaptive quadrature, at most 100 subintervals).
    The error estimate quad reports is discarded.
    """
    return scipy.integrate.quad(func, low, high, limit=100)[0]

def inverse(f,target,lower,upper,maxlike,findleft=True,findright=True):
    """Use bisection to find the two inverses of the value target of the function f.

    That is, find the two x such that f(x) == target: one to the left of the
    peak at maxlike, searched in [lower, maxlike], and one to the right,
    searched in [maxlike, upper].

    findleft / findright select which side(s) to search.

    Returns (left, right) when both sides are requested, the single root
    when only one side is requested, and None when neither is.

    Raises:
        LeftError: only the left search failed.
        RightError: only the right search failed.
        InvertError: both searches failed.
    A diagnostic report is printed for every failed side.
    """
    def _report(side):
        # Printed diagnostics help spot distributions that are cut off
        # before they fall to the requested level.
        print("Inversion Failure on %s - does your distribution cut off "
              "before the symmetric level?" % side)
        print("DATA:")
        print("target: %s" % (target,))
        print("Lower,f(Lower): %s %s" % (lower, f(lower)))
        print("Upper,f(Upper): %s %s" % (upper, f(upper)))
        print("maximum,f(maximum): %s %s" % (maxlike, f(maxlike)))

    shifted = lambda x: f(x) - target
    left = right = None
    leftfail = rightfail = False
    if findleft:
        try:
            left = scipy.optimize.bisect(shifted, lower, maxlike, maxiter=500)
        except ValueError:
            # bisect raises ValueError when f - target does not change sign
            # over the interval.
            _report("left")
            leftfail = True
    if findright:
        try:
            right = scipy.optimize.bisect(shifted, maxlike, upper, maxiter=500)
        except ValueError:
            _report("right")
            rightfail = True
    if leftfail and rightfail:
        raise InvertError("inversion failed on both sides of the peak "
                          "(target=%r)" % (target,))
    if leftfail:
        raise LeftError("inversion failed left of the peak (target=%r)" % (target,))
    if rightfail:
        raise RightError("inversion failed right of the peak (target=%r)" % (target,))
    if findleft and findright:
        return left, right
    if findleft:
        return left
    if findright:
        return right
    return None

def onetail(x,y,stepfactor=100,plotting=False):
    """
    Find limits by sweeping the cumulative area from the left edge.

    Starting at x[0], the upper integration limit is advanced in steps of
    (x[-1]-x[0])/stepfactor.  The positions where the enclosed fraction of
    the total area first exceeds 5%, 34%, 68% and 95% are recorded.  The
    sweep never goes past x[-1].

    If plotting is true, the curve and the four limits are drawn with pylab.

    Returns (lower2, lower1, location_of_maximum, upper1, upper2), i.e. the
    5%, 34% positions, the x of the maximum of y, and the 68%, 95% positions.
    Raises ValueError if the total area under the curve is not positive.
    """
    x = numpy.asarray(x, dtype=float)
    y = numpy.asarray(y, dtype=float)
    location_of_maximum = x[y.argmax()]
    f = interp(x, y)
    total_area = tab_area(x, y, x[0], x[-1], interpolator=f)
    if not total_area > 0:
        raise ValueError("the curve must enclose a positive total area")
    # float() avoids integer division under Python 2.
    step = (x[-1] - x[0]) / float(stepfactor)
    right = x[0]
    current_area = 0.0
    limits = []
    for fraction in (0.05, 0.34, 0.68, 0.95):
        # Clamp to x[-1]: accumulated float steps can overshoot the table,
        # which would make tab_area reject the bound.
        while current_area / total_area <= fraction and right < x[-1]:
            right = min(right + step, x[-1])
            current_area = tab_area(x, y, x[0], right, interpolator=f)
        limits.append(right)
    lower2, lower1, upper1, upper2 = limits
    if plotting:
        pylab.plot(x, y)
        for position, colour in ((lower2, 'r'), (lower1, 'g'),
                                 (upper2, 'r'), (upper1, 'g')):
            pylab.plot([position, position], [0.0, f(position)], colour)
    return lower2, lower1, location_of_maximum, upper1, upper2
    



def maxlike(x,y,stepfactor=100,plotting=False):
    """Find the highest-density 68% and 95% intervals and the peak location.

    A horizontal level is lowered from max(y) in steps of max(y)/stepfactor.
    At each level the crossings of the curve on either side of the peak are
    found with inverse(), and the area between them is measured.  The first
    interval enclosing more than 68% (then 95%) of the total area is
    reported.  If the curve does not fall to the level on one side, that
    side is pinned to the corresponding end of the tabulated range.

    If plotting is true, the curve and the four limits are drawn with pylab.

    Returns (lower2, lower1, location_of_maximum, upper1, upper2).
    Raises ValueError if the total area under the curve is not positive.
    """
    x = numpy.asarray(x, dtype=float)
    y = numpy.asarray(y, dtype=float)
    lo, hi = x[0], x[-1]
    location_of_maximum = x[y.argmax()]
    f = interp(x, y)
    total_area = tab_area(x, y, lo, hi, interpolator=f)
    if not total_area > 0:
        raise ValueError("the curve must enclose a positive total area")
    step_size = y.max() / float(stepfactor)

    def _interval(level):
        """Return (left, right) crossings of level, pinned to lo/hi where no crossing exists."""
        try:
            return inverse(f, level, lo, hi, location_of_maximum)
        except LeftError:
            return lo, inverse(f, level, lo, hi, location_of_maximum, findleft=False)
        except RightError:
            return inverse(f, level, lo, hi, location_of_maximum, findright=False), hi
        except InvertError:
            # Neither side reaches the level: the interval is the whole range.
            return lo, hi

    current_level = y.max()
    current_area = 0.0
    intervals = []
    for fraction in (0.68, 0.95):
        # Terminates: once the level drops below the curve at both ends the
        # interval is (lo, hi), whose area equals total_area.
        while current_area / total_area <= fraction:
            current_level -= step_size
            left, right = _interval(current_level)
            current_area = tab_area(x, y, left, right, interpolator=f)
        intervals.append((left, right))
    (lower1, upper1), (lower2, upper2) = intervals

    if plotting:
        pylab.plot(x, y)
        for position, colour in ((lower2, 'r'), (lower1, 'g'),
                                 (upper2, 'r'), (upper1, 'g')):
            pylab.plot([position, position], [0.0, f(position)], colour)
    return lower2, lower1, location_of_maximum, upper1, upper2
    
def postlike(x,y,stepfactor=100,plotting=False):
    """Find equal-tailed 68% and 95% limits of a tabulated distribution, and its mean.

    Starting at x[0], the upper integration limit is advanced in steps of
    (x[-1]-x[0])/stepfactor.  The positions where the enclosed fraction of
    the total area first exceeds 2.5%, 16%, 84% and 97.5% are recorded.  The
    sweep never goes past x[-1].

    The mean is integral(x*y dx) / integral(y dx) over the tabulated range.

    If plotting is true, the curve and the four limits are drawn with pylab.

    Returns (lower2, lower1, mean_value, upper1, upper2).
    Raises ValueError if the total area under the curve is not positive.
    """
    x = numpy.asarray(x, dtype=float)
    y = numpy.asarray(y, dtype=float)
    f = interp(x, y)
    total_area = tab_area(x, y, x[0], x[-1], interpolator=f)
    if not total_area > 0:
        raise ValueError("the curve must enclose a positive total area")
    # float() avoids integer division under Python 2.
    step = (x[-1] - x[0]) / float(stepfactor)
    right = x[0]
    current_area = 0.0
    limits = []
    for fraction in (0.025, 0.16, 0.84, 0.975):
        # Clamp to x[-1]: accumulated float steps can overshoot the table,
        # which would make tab_area reject the bound.
        while current_area / total_area <= fraction and right < x[-1]:
            right = min(right + step, x[-1])
            current_area = tab_area(x, y, x[0], right, interpolator=f)
        limits.append(right)
    lower2, lower1, upper1, upper2 = limits
    if plotting:
        pylab.plot(x, y)
        for position, colour in ((lower2, 'r'), (lower1, 'g'),
                                 (upper2, 'r'), (upper1, 'g')):
            pylab.plot([position, position], [0.0, f(position)], colour)
    # Divide by the total area: y is not assumed to be normalised.
    mean_value = tab_area(x, x * y, x[0], x[-1]) / total_area
    return lower2, lower1, mean_value, upper1, upper2


def normalize(x,y):
    """Rescale y so the tabulated curve y(x) has unit area over [x[0], x[-1]].

    The area is computed with tab_area.  Returns x unchanged together with
    the rescaled y.
    """
    scale = tab_area(x, y, x[0], x[-1])
    normalized_y = y / scale
    return x, normalized_y
   
def calibration(x,y,w,steps=100):
    """
    Marginalize numerically over the calibration uncertainty.

    w is the gaussian width (standard deviation) of the calibration
    uncertainty around 1.0.  The returned curve is

        y2(X) = integral_{0.5}^{1.5} p(X * f**2) * N(f) df

    Here p is the linear interpolation of (x, y), taken as zero outside
    [x[0], x[-1]], and N is the normalised Gaussian of mean 1.0 and standard
    deviation w.  X runs over numpy.arange(1.5*x[0]-0.5*x[-1],
    1.5*x[-1]-0.5*x[0], 2*(x[-1]-x[0])/steps), so roughly `steps` points.

    Returns (x2, y2) as float arrays.
    Raises ValueError if w is not positive.
    """
    if not w > 0:
        raise ValueError("calibration width w must be positive")
    p = interp(x, y)
    lo, hi = x[0], x[-1]
    norm = 1.0 / (sqrt(2 * pi) * w)

    def gaussian(t):
        # The exponent needs the factor 2 to match the 1/(sqrt(2*pi)*w)
        # normalisation of a Gaussian with standard deviation w.
        return norm * exp(-(t - 1.0) ** 2 / (2.0 * w * w))

    def integrand(f, X):
        scaled = X * (f * f)
        if scaled < lo or scaled > hi:
            # Outside the tabulated range the distribution is taken as zero.
            return 0.0
        return float(p(scaled)) * gaussian(f)

    x2 = numpy.arange(1.5 * lo - 0.5 * hi, 1.5 * hi - 0.5 * lo, 2.0 * (hi - lo) / steps)
    y2 = numpy.array([scipy.integrate.quad(integrand, 0.5, 1.5, args=(X,), limit=200)[0]
                      for X in x2], dtype=float)
    return x2, y2


