import os
import sys
import numpy as np
import matplotlib.pyplot as plt

import larch
from larch import Interpreter, Group
from larch.fitting import guess, minimize

from larch_plugins.math.lineshapes import gaussian
from larch_plugins.xafs import pre_edge, mback

import pandas as pd


def rdin(filename, skiprows_options=(12, 8)):
    """Read a tab-separated scan file into a pandas DataFrame.

    The number of header lines before the column-name row is not the same
    for every file (the original note says TrajScan files need 8), so each
    value in *skiprows_options* is tried in order until the parsed columns
    contain "Counter 0".

    Args:
        filename: path of the tab-separated scan file.
        skiprows_options: header line counts to try, in order. The default
            (12, 8) reproduces the previously hard-coded behaviour.

    Returns:
        The DataFrame on success. Returns None (after printing a message)
        when no attempt produced a "Counter 0" column.
    """
    scandata_f = None
    for skiprows in skiprows_options:
        scandata_f = pd.read_csv(filename, sep='\t', skiprows=skiprows)
        if "Counter 0" in scandata_f.columns:
            break

    if scandata_f is None:
        # skiprows_options was empty, so no read was attempted at all.
        print("Problem with header: no header sizes to try. Check input file.")
        return None

    # print() call form works under both Python 2 and Python 3.
    print(scandata_f.columns)

    if "Counter 0" not in scandata_f.columns:
        # Build the message from the values actually tried so it stays accurate.
        print("Problem with header. skipping %s lines did not make it. Check input file."
              % " or ".join(str(n) for n in skiprows_options))
        return None
    return scandata_f



def prepare_scan(scandata_f, datacounter="Counter 3", reference_counter='Counter 2'):
    """Add normalised-intensity columns to a scan DataFrame, in place.

    Adds two normalisations of the *datacounter* column:

    * ``I_Norm0``: *datacounter* divided by *reference_counter*.
    * ``I_Normt``: *datacounter* divided by a clock counter. The clock is the
      first of 'Counter 4' and 'Counter 6' present in the frame. If neither
      exists, it falls back to 'Counter 0' and prints a notice.

    The ``Energy`` column is rounded to one decimal place. The same
    DataFrame object that was passed in is modified and returned.
    """
    present = scandata_f.columns
    clock = next((col for col in ('Counter 4', 'Counter 6') if col in present), None)
    if clock is None:
        print("No counter for clock found (looked for 'Counter 4' and 'Counter 6'). Defaulting to 'Counter 0'.")
        clock = 'Counter 0'

    def _as_float(column):
        # Counter columns may be parsed as integers; divide as floats.
        return scandata_f[column].astype(float)

    scandata_f["I_Norm0"] = _as_float(datacounter) / _as_float(reference_counter)
    scandata_f["I_Normt"] = _as_float(datacounter) / _as_float(clock)
    scandata_f["Energy"] = scandata_f["Energy"].round(1)
    return scandata_f

# Load and normalise the scan. rdin() returns None when it cannot find the
# "Counter 0" header; stop here instead of failing later with an
# AttributeError on None.
scandata_f = rdin("TrajScan21930_swf.txt")
if scandata_f is None:
    sys.exit("Could not read scan data from 'TrajScan21930_swf.txt'.")
scandata_f = prepare_scan(scandata_f)


mylarch = Interpreter()

# Collect the fit data in a larch Group (the same container used for the
# parameters below). Assigning new attributes to a pandas DataFrame
# (mdat.x = ...) triggers a pandas UserWarning and does not create columns.
energy = np.asarray(scandata_f["Energy"], dtype=float)
signal = np.asarray(scandata_f["I_Norm0"], dtype=float)
# I_Norm0 is a ratio of counters, so a zero reference count gives inf/NaN.
# Drop such points, since non-finite residuals generally break a
# least-squares fit.
finite = np.isfinite(energy) & np.isfinite(signal)
mdat = Group(x=energy[finite], y=signal[finite])

# create a group of fit parameters
params = Group(off=guess(0),
               amp=guess(5, min=0),
               cen=guess(535),
               wid=guess(1, min=0))
# NOTE: 'cen' needs a first guess for the centre of the gaussian (an edge
# position). The hard-coded 535 could instead come from an edge finder.

# Model evaluated at the starting values, kept for comparison with the fit.
init = params.off + params.amp * gaussian(mdat.x, params.cen, params.wid)


def resid(p, data):
    """Objective function for the fit: data minus a Gaussian-plus-offset model.

    Args:
        p: parameter group with ``off``, ``amp``, ``cen`` and ``wid`` entries.
        data: object exposing the x values as ``data.x`` and the measured
            values as ``data.y``.

    Returns:
        ``data.y - (off + amp * gaussian(x, cen, wid))``.
    """
    model = p.off + p.amp * gaussian(data.x, p.cen, p.wid)
    return data.y - model


# Perform the fit. The code below relies on minimize() updating the entries
# of `params` in place, so they hold the fitted values afterwards.
minimize(resid, params, args=(mdat,), _larch=mylarch)

# Evaluate the model with the fitted parameters.
final = params.off + params.amp * gaussian(mdat.x, params.cen, params.wid)

# Plot data, initial guess and fit with labels so the curves can be told apart.
plt.plot(mdat.x, mdat.y, label='data')
plt.plot(mdat.x, init, label='initial guess')
plt.plot(mdat.x, final, label='fit')
plt.xlabel('Energy')
plt.ylabel('I_Norm0')
plt.legend()
plt.show()
