"""
             ***TEST***
    ***CORRELATION FUNCTIONS***

    Simulations for SU(2) gauge Higgs
        - Single higgs field
        - Gradient flow
        - correlation function
"""

import Pkg
# Activate the LatticeGPU project environment. The location can be overridden
# with the LATTICEGPU_PROJECT environment variable (useful on other machines /
# clusters); the default is the original hard-coded path, so existing runs
# behave exactly as before.
Pkg.activate(get(ENV, "LATTICEGPU_PROJECT",
                 "/home/gtelo/PhD/LatGPU/SU2proj/code/runs_su2higgs/latticegpu.jl"))
Pkg.status()


using CUDA, Logging, StructArrays, Random, TimerOutputs, ADerrors, Printf, BDIO,LatticeGPU
# Report the CUDA setup and forbid element-by-element (scalar) GPU indexing
CUDA.versioninfo()
CUDA.allowscalar(false)

# Set lattice/block size: an 8x8x8x16 lattice in 4x4x4x4 blocks, BC_PERIODIC
# boundary conditions, all six twist components equal to zero
ntwist = ntuple(_ -> 0, 6)
lp = SpaceParm{4}((8, 8, 8, 16), (4, 4, 4, 4), BC_PERIODIC, ntwist)
println("Space  Parameters: ", lp)

# Seed RNG: the CURAND generator first, then Julia's default generator
println("Seeding CURAND...")
Random.seed!(CURAND.default_rng(), 1234)
Random.seed!(1234)

# Set group and precision
GRP  = SU2       # gauge group type
ALG  = SU2alg    # matching algebra type
SCL  = SU2fund   # element type of the scalar field(s)
PREC = Float64   # floating-point precision used for every field
NSC  = 1         # number of scalar fields (single Higgs)
println("Precision: ", PREC)


println("Allocating YM workspace")
ymws = YMworkspace(GRP, PREC, lp)

# Main program
# Cold start: every gauge link is set to the group identity
println("Allocating gauge field")
U = vector_field(GRP{PREC}, lp)
fill!(U, one(GRP{PREC}))

# Time the copy of the configuration into a plain Array (host memory)
println("Time to take the configuration to memory: ")
@time Ucpu = Array(U)

# Set gauge parameters
# FIRST SET: Wilson action/flow
println("\n## WILSON ACTION/FLOW TIMES")

"""
    arg_or_default(i, default)

Return command-line argument `ARGS[i]` parsed as a `Float64`, or `default` when
fewer than `i` arguments were given. Throw an `ArgumentError` when the argument
is present but is not a valid number: a bare `tryparse` would silently yield
`nothing`, which would only surface later as an obscure error downstream.
"""
function arg_or_default(i::Integer, default::Float64)
    i <= length(ARGS) || return default
    val = tryparse(Float64, ARGS[i])
    val === nothing &&
        throw(ArgumentError("command-line argument $i (\"$(ARGS[i])\") is not a valid Float64"))
    return val
end

# Scalar parameters. Usage: julia <script> [k1 [beta [eta]]]
# Missing trailing arguments fall back to the defaults below.
k1   = arg_or_default(1, 0.268)
beta = arg_or_default(2, 2.4)
eta  = arg_or_default(3, 0.5)
sp = ScalarParm((k1,), (eta,))
println("Scalar parameters: ", sp)

println("Allocating Scalar workspace")
sws  = ScalarWorkspace(PREC, NSC, lp)
# Scalar fields φ=0
println("Allocating scalar field")
# Phi includes all scalar fields
Phi = nscalar_field(SCL{PREC}, NSC, lp)
# starting value
fill!(Phi, zero(SCL{PREC}))

# MC SIMULATION PARAMETERS ##################################################
# MD integrator (omf4): step size dt, nsteps steps per trajectory
dt     = 0.05
nsteps = 60
int    = omf4(PREC, dt, nsteps)
println(int)

nth   = 100    # thermalization length
niter = 1000   # MC length

# Global Observables: one entry per trajectory (thermalization + MC chain)
pl, rho1, Lphi1, Lalp1 = (Vector{Float64}(undef, niter + nth) for _ in 1:4)
# Correlators: one slice per MC trajectory
# Higgs: (scalar field, lp.iL[end] slices, trajectory)
h2 = Array{Float64,3}(undef, NSC, lp.iL[end], niter)
# W boson: (scalar field, 3 spatial directions, 3 Pauli matrices, lp.iL[end] slices, trajectory)
w1 = Array{Float64,5}(undef, NSC, 3, 3, lp.iL[end], niter)
# HMC: energy violation dh and acceptance flag, one per trajectory
dh  = Vector{Float64}(undef, niter + nth)
acc = Vector{Bool}(undef, niter + nth)

# FLOW PARAMETERS
Uflw = vector_field(GRP{PREC}, lp)   # working copy of U for flowing / smearing
flw_steps = 0    # GF total flow calls per MC element
flw_s     = 10   # steps of integration for each call
flw_int   = 10   # MC gap between each flow call
flw_dt    = 0.01 # fixed step size flow
gfinfo = string("_GF_stps",flw_steps,"_s",flw_s,"_int",flw_int,"_dt",flw_dt)
flwint = wfl_rk3(PREC, flw_dt, 1.0E-6)
println(flwint)

# Total number of flow measurements. The MC loop triggers a flow every flw_int
# trajectories, i.e. floor(niter/flw_int) times, so integer division is exact;
# `convert(Int64, niter/flw_int)` threw an InexactError whenever flw_int did
# not divide niter.
flw_iter = div(niter, flw_int)

# Flow time after each flow call: f * flw_s * flw_dt. The typed comprehension
# keeps eltype Float64 even when flw_steps == 0 (empty vector).
flwtime = Float64[f * flw_s * flw_dt for f in 1:flw_steps]

# Flow observables indexed as (flow measurement, flow call).
# dE and dEcl are allocated but never filled or saved by this script.
E       = Array{Float64, 2}(undef, flw_iter, flw_steps)
dE      = Array{Float64, 2}(undef, flw_iter, flw_steps)
Ecl     = Array{Float64, 2}(undef, flw_iter, flw_steps)
dEcl    = Array{Float64, 2}(undef, flw_iter, flw_steps)

# Correlation function - Smearing
sm  = true    # true: correlators from smeared copies of the fields; false: from U, Phi directly
sus = 50      # smearing steps, tagged "_us_" in file names (presumably gauge links — confirm in smr)
sdt = flw_dt  # step size - GF smearing
sss = 20      # smearing steps, tagged "_ss_" in file names (presumably scalar field — confirm in smr)
srs = 0.1     # r_smear
smear = smr{PREC}(sus, sdt, sss, srs)
Phismear = nscalar_field(SCL{PREC}, NSC, lp)   # working copy of Phi for smearing
sminfo = "_SM_us_$(sus)_ss_$(sss)"

# Save observables in BDIO file (one record per uinfo value)
# Uinfo 1: Simulation parameters Int64 - [nth, niter, flw_s, flw_steps, flw_int, flw_iter, sus, sss]
# Uinfo 2: Simulation parameters Float64 - [beta, flw_dt, k1, eta, sdt, srs]
# Uinfo 3: Measurements plaquette
# Uinfo 4: Measurements rho1
# Uinfo 5: Measurements Lphi1
# Uinfo 6: Measurements Lalp1
# Uinfo 7: hash (checksum records written by BDIO_write_hash!)
# Uinfo 8: Measurement dh
# Uinfo 9: Flow Measurements time
# Uinfo 10: Measurement Flow E_plaq
# Uinfo 11: Measurement Flow E_clv
# Uinfo 12: Correlator - Higgs interpolator h2 (per trajectory: lp.iL[end] values)
# Uinfo 13: Correlator - W-boson interpolator w1 (per trajectory, direction mu and Pauli index: lp.iL[end] values)

#TODO add explanation of bdio files in bdio file

# Lattice dimensions as "L1xL2x...xLn" (e.g. "8x8x8x16"), used in the file name
dm = join(vcat([lp.iL[i] for i in 1:lp.ndim-1], lp.iL[end]), "x")

# Scalar-parameter tag for the output file name
sclr_s = "_k$(k1)_eta$(eta)"
#################################### SIMULATION ###################################
#################################### Simulate different beta values
# Currently one run at the `beta` read above; the outer loop is the place to
# add a scan over several beta values.
for sim in 1:1
    gp = GaugeParm{PREC}(GRP{PREC}, beta, 1.0)
    println("Gauge  Parameters: ", gp)

    filename = string("simulations/test_higgs",NSC,"_",dm,"_beta",beta,sclr_s,"_niter", niter,"_eps",dt,"_nsteps",nsteps,gfinfo,sminfo,".bdio")
    fb = BDIO_open(filename, "d",
                "$(NSC) Scalar Higgs Simulations - Gradient Flow for gauge fields; correlation functions")
    # try/finally: the BDIO handle is closed even if the run or a write throws.
    try
        # Simulation param
        # Record 1 (Int64): [nth, niter, flw_s, flw_steps, flw_int, flw_iter, sus, sss]
        iv = [nth, niter, flw_s, flw_steps, flw_int, flw_iter, sus, sss]
        BDIO_start_record!(fb, BDIO_BIN_INT64LE, 1, true)
        BDIO_write!(fb, iv)
        BDIO_write_hash!(fb)

        # Record 2 (Float64): [beta, flw_dt, k1, eta, sdt, srs]
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 2, true)
        s_parms = [beta, flw_dt, k1, eta, sdt, srs]
        BDIO_write!(fb, s_parms)
        BDIO_write_hash!(fb)

        # k indexes the per-trajectory arrays across thermalization and MC chain
        k = 0
        # One initial trajectory with noacc=true (presumably skips the
        # accept/reject step — confirm in HMC!); its output is not recorded.
        HMC!(U, Phi, int, lp, gp, sp, ymws, sws; noacc=true)

        # Thermalization
        for j in 1:nth
            k += 1
            dh[k], acc[k] = HMC!(U, Phi, int, lp, gp, sp, ymws, sws)
            pl[k] = plaquette(U, lp, gp, ymws)
            # observables φ1
            rho1[k], Lphi1[k], Lalp1[k] = scalar_obs(U, Phi, 1, sp, lp, ymws)
            @printf("  THM %d/%d (beta: %4.3f):   %s   %6.2e    %20.12e    %20.12e    %20.12e    %20.12e\n",
                j, nth, beta, acc[k] ? "true " : "false", dh[k],
                pl[k], rho1[k], Lphi1[k], Lalp1[k])
        end
        println(" ")

        # MC chain
        kflw = 0   # trajectories since the last flow measurement
        jflw = 0   # index of the current flow measurement (row of E / Ecl)
        for j in 1:niter
            k += 1
            dh[k], acc[k] = HMC!(U, Phi, int, lp, gp, sp, ymws, sws)
            pl[k] = plaquette(U, lp, gp, ymws)
            rho1[k], Lphi1[k], Lalp1[k] = scalar_obs(U, Phi, 1, sp, lp, ymws)
            # Correlators. With smearing, scalar_corr receives copies (Uflw,
            # Phismear) — presumably because smearing overwrites its inputs.
            if sm
                Uflw .= U
                Phismear .= Phi
                h2[1,:,j], w1[1,:,:,:,j] = scalar_corr(Uflw, Phismear, 1, smear, sp, lp, ymws, gp, sws)
            else
                h2[1,:,j], w1[1,:,:,:,j] = scalar_corr(U, Phi, 1, sp, lp, ymws, gp, sws)
            end

            @printf("  MSM %d/%d (beta: %4.3f):   %s   %6.2e    %20.12e    %20.12e    %20.12e    %20.12e\n",
                j, niter, beta, acc[k] ? "true " : "false", dh[k],
                pl[k], rho1[k], Lphi1[k], Lalp1[k])

            # flow every 'flw_int'
            kflw += 1
            if kflw == flw_int
                print("\n\t## START Flow:")
                jflw += 1
                Uflw .= U
                for f in 1:flw_steps
                    # flw_s integration steps per call, then measure E(t)
                    flw(Uflw, flwint, flw_s, gp, lp, ymws)
                    E[jflw,f]   = Eoft_plaq(Uflw, gp, lp, ymws)
                    Ecl[jflw,f] = Eoft_clover(Uflw, gp, lp, ymws)
                end
                kflw = 0
                println("\n\t## END flow measurements")
            end
        end

        ################################### SAVE TO FILE ##################################
        # Every record is followed by a hash record (uinfo 7), records 3-6 included.
        # Plaquette
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 3, true)
        BDIO_write!(fb, pl)
        BDIO_write_hash!(fb)

        # Rho1
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 4, true)
        BDIO_write!(fb, rho1)
        BDIO_write_hash!(fb)

        # Lphi1
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 5, true)
        BDIO_write!(fb, Lphi1)
        BDIO_write_hash!(fb)

        # Lalp1
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 6, true)
        BDIO_write!(fb, Lalp1)
        BDIO_write_hash!(fb)

        #Energy conservation
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 8, true)
        BDIO_write!(fb, dh)
        BDIO_write_hash!(fb)

        # FLOW OBSERVABLES
        # Flow time
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 9, true)
        BDIO_write!(fb, flwtime)
        BDIO_write_hash!(fb)
        # E Plaquette
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 10, true)
        BDIO_write!(fb, E)
        BDIO_write_hash!(fb)
        # E Clover
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 11, true)
        BDIO_write!(fb, Ecl)
        BDIO_write_hash!(fb)

        #CORRELATION FUNCTIONS
        # Higgs interpolator: for each trajectory, its lp.iL[end] values
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 12, true)
        for it in 1:niter
            BDIO_write!(fb, h2[1,:,it])
        end
        BDIO_write_hash!(fb)
        # W-boson interpolator: order trajectory > direction mu > Pauli index > slice
        BDIO_start_record!(fb, BDIO_BIN_F64LE, 13, true)
        for it in 1:niter
            for mu in 1:3
                for a in 1:3
                    BDIO_write!(fb, w1[1,mu,a,:,it])
                end
            end
        end
        BDIO_write_hash!(fb)
    finally
        BDIO_close!(fb)
    end
end

# Three blank lines, then the TimerOutputs report (default timer; no @timeit
# sections are visible in this script) and the end marker.
print("\n\n\n")

println("## Timming results")
print_timer(linechars = :ascii)


println("## END")
# END
