"""
    Simulations with 2 scalar doublets
    Fixed k_2, explore k_1 parameter space
"""


using CUDA, Logging, StructArrays, Random, TimerOutputs, ADerrors, Printf, BDIO


CUDA.allowscalar(false)
import Pkg
Pkg.activate("/home/gtelo/PhD/SU2proj/code/runs_su2higgs/latticegpu-su2-higgs")
using LatticeGPU

# Lattice geometry: 8^4 lattice. NOTE(review): the second tuple is presumably the
# GPU block decomposition expected by LatticeGPU.SpaceParm — confirm there.
lp = SpaceParm{4}((8, 8, 8, 8), (4, 4, 4, 4))

# Gauge coupling. Hard-coded; this script does not read ARGS.
beta = 2.25
gp = GaugeParm(beta, 1.0, (0.0, 0.0), 2)

# Scalar couplings: k2 stays fixed while k1 is scanned in the main loop.
k2 = 0.0
et1, et2 = 0.01, 0.01
xi1 = xi2 = xi3 = xi4 = 0.1
mu = 0.2

# Number of scalar fields (the observables below assume two).
NSC = 2
println("Space  Parameters: \n", lp)
println("Gauge  Parameters: \n", gp)

# Group / algebra / scalar representation types and floating-point precision.
GRP, ALG, SCL = SU2, SU2alg, SU2fund
PREC = Float64
println("Precision:         ", PREC)

# Work spaces used by the gauge and scalar HMC routines.
println("Allocating YM workspace")
ymws = YMworkspace(GRP, PREC, lp)
println("Allocating Scalar workspace")
sws = ScalarWorkspace(PREC, NSC, lp)

# Seed the CURAND generator and the host default RNG with the same value,
# so runs are reproducible.
println("Seeding CURAND...")
rng_seed = 1234
Random.seed!(CURAND.default_rng(), rng_seed)
Random.seed!(rng_seed)

# Main program.
# Cold start: every link is set to the identity and every scalar to zero.
println("Allocating gauge field")
U = vector_field(GRP{PREC}, lp)
fill!(U, one(GRP{PREC}))

println("Allocating scalar field")
# Phi stores all NSC scalar fields at once.
Phi = nscalar_field(SCL{PREC}, NSC, lp)
fill!(Phi, zero(SCL{PREC}))

# HMC integrator settings: MD step size and number of MD steps per trajectory.
dt = 0.005
nsteps = 250

# Hysteresis run? When true, the k1 scan is repeated backwards (see the main loop).
hist = false

nth = 50             # thermalization trajectories per k1 value
niter = 100          # measurement trajectories per k1 value
nrks = 30            # number of k1 values in the scan
h = (0.3/nrks)*2     # spacing between consecutive k1 values

# Per-trajectory observable histories, reused for every k1 value:
# entries 1:nth hold thermalization, entries nth+1:end hold measurements.
ntot = nth + niter
pl, rho1, Lphi1, Lalp1, rho2, Lphi2, Lalp2, dh =
    ntuple(_ -> Vector{Float64}(undef, ntot), 8)
acc = Vector{Bool}(undef, ntot)

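# Save observables in a BDIO file, one record per Uinfo tag
# (records 3-14 are written once for every k1 value in the scan):
# Uinfo 1:  Simulation parameters [nth, niter, nrks]
# Uinfo 2:  List of kappa (k1) values, followed by [k2, et1, et2, mu, xi1, xi2, xi3, xi4]
# Uinfo 3:  Thermalization plaquette
# Uinfo 4:  Measurements   plaquette
# Uinfo 5:  Thermalization rho  (field 1, then field 2)
# Uinfo 6:  Measurements   rho  (field 1, then field 2)
# Uinfo 7:  hash
# Uinfo 8:  Thermalization Lphi (field 1, then field 2)
# Uinfo 9:  Measurements   Lphi (field 1, then field 2)
# Uinfo 10: Thermalization Lalp (field 1, then field 2)
# Uinfo 11: Measurements   Lalp (field 1, then field 2)
# Uinfo 12: Thermalization dh
# Uinfo 13: Measurements   dh
# Uinfo 14: Acceptance rate of the measurement trajectories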

# Lattice dimensions as "L1xL2x...xLd", used in the output file name.
dm = join([lp.iL[1:lp.ndim-1]..., lp.iL[end]], "x")

# Scalar couplings, encoded in the output file name.
sclr_s = "_k2$(k2)_etas$(et1)_$(et2)_mu$(mu)_xis$(xi1)_$(xi2)_$(xi3)_$(xi4)"
# Hysteresis runs get a distinct file-name tag.
scan_s = hist ? "_hist_slow_vark" : "_vark"
filename = string("var_k/scalar", NSC, "_", dm, scan_s, nrks, "_beta", beta, sclr_s,
                  "_niter", niter, "_eps", dt, "_nsteps", nsteps, ".bdio")

# Make sure the output directory exists before opening the file; otherwise the run
# dies here after all the setup work. mkpath is a no-op if the directory exists.
mkpath(dirname(filename))
fb = BDIO_open(filename, "d",
               "Scalar simulations with $NSC scalar fields")

# Record 1: simulation parameters.
iv = [nth, niter, nrks]
BDIO_start_record!(fb, BDIO_BIN_INT64LE, 1, true)
BDIO_write!(fb, iv)
BDIO_write_hash!(fb)

# Record 2: the k1 grid, then the fixed scalar couplings.
kstart = 0.0
kv = [h*(i-1) + kstart for i in 1:nrks]
s_parms = [k2, et1, et2, mu, xi1, xi2, xi3, xi4]
BDIO_start_record!(fb, BDIO_BIN_F64LE, 2, true)
BDIO_write!(fb, kv)
BDIO_write!(fb, s_parms)
BDIO_write_hash!(fb)

# Number of k1 points to simulate: a hysteresis run walks the grid up and then
# back down, which doubles the scan length.
ks = hist ? 2*nrks : nrks

"""
    measure_obs(U, Phi, sp, lp, gp, ymws)

Measure the current configuration. Return the 7-tuple
`(plaquette, rho1, Lphi1, Lalp1, rho2, Lphi2, Lalp2)`, where the last six entries are
the three values returned by `scalar_obs` for scalar field 1 and then for scalar field 2.
"""
function measure_obs(U, Phi, sp, lp, gp, ymws)
    plaq = plaquette(U, lp, gp, ymws)
    r1, lphi1, lalp1 = scalar_obs(U, Phi, 1, sp, lp, ymws)
    r2, lphi2, lalp2 = scalar_obs(U, Phi, 2, sp, lp, ymws)
    return plaq, r1, lphi1, lalp1, r2, lphi2, lalp2
end

"""
    print_obs_line(tag, j, n, kap, accepted, dh, plaq, r1, lphi1, lalp1, r2, lphi2, lalp2)

Print one monitoring line for trajectory `j` of `n` in phase `tag` ("THM", "MSM", ...).
The line shows the coupling `kap`, the accept flag, ΔH, the plaquette and the scalar
observables of both fields.
"""
function print_obs_line(tag, j, n, kap, accepted, dh, plaq, r1, lphi1, lalp1, r2, lphi2, lalp2)
    @printf("  %s %d/%d (kappa: %4.3f):   %s   %6.2e    %20.12e    %20.12e    %20.12e    %20.12e   %20.12e    %20.12e    %20.12e\n",
            tag, j, n, kap, accepted ? "true " : "false", dh,
            plaq, r1, lphi1, lalp1, r2, lphi2, lalp2)
    return nothing
end

"""
    write_thm_msm!(fb, uinfo_thm, uinfo_msm, nth, series...)

Write two Float64 BDIO records, each followed by a hash record.
Record `uinfo_thm` holds `s[1:nth]` for every series `s` in order (thermalization).
Record `uinfo_msm` holds `s[nth+1:end]` for every series in order (measurements).
"""
function write_thm_msm!(fb, uinfo_thm, uinfo_msm, nth, series...)
    BDIO_start_record!(fb, BDIO_BIN_F64LE, uinfo_thm, true)
    for s in series
        BDIO_write!(fb, s[1:nth])
    end
    BDIO_write_hash!(fb)
    BDIO_start_record!(fb, BDIO_BIN_F64LE, uinfo_msm, true)
    for s in series
        BDIO_write!(fb, s[nth+1:end])
    end
    BDIO_write_hash!(fb)
    return nothing
end

for i in 1:ks
    # Scanned coupling k1: forward over the grid for i <= nrks. In a hysteresis run the
    # same grid is then walked backwards, starting again from its largest value.
    k1 = i > nrks ? h*(2*nrks-i) + kstart : h*(i-1) + kstart

    if NSC == 1
        # Previously referenced an undefined `eta` and ignored kstart/the backward scan.
        # NOTE(review): the measurements below always read scalar field 2, so NSC == 1
        # presumably still fails there — confirm before relying on this branch.
        sp = ScalarParm((k1,), (et1,))
    elseif NSC == 2
        sp = ScalarParm((k1, k2), (et1, et2), mu, (xi1, xi2, xi3, xi4))
    else
        throw(ArgumentError("unsupported number of scalar fields NSC = $NSC (expected 1 or 2)"))
    end
    println("## Simulating Scalar parameters: ")
    println(sp)

    # One trajectory with `noacc=true` right after k1 changes. The original author asked
    # why this is here. NOTE(review): presumably it skips the accept/reject step so the
    # configuration is always updated for the new couplings — confirm in LatticeGPU's HMC!.
    HMC!(U, Phi, dt, nsteps, lp, gp, sp, ymws, sws; noacc=true)

    # Hysteresis run: extra thermalization at the first k1 value and at the turning
    # point (largest k1). These trajectories are printed but not stored.
    if hist && (i == 1 || i == nrks)
        nextra = 50
        for j in 1:nextra
            DH, ACC = HMC!(U, Phi, dt, nsteps, lp, gp, sp, ymws, sws)
            # Progress counter shows nextra; it previously (wrongly) showed nth.
            print_obs_line("Extra THM", j, nextra, sp.kap[1], ACC, DH,
                           measure_obs(U, Phi, sp, lp, gp, ymws)...)
        end
    end

    # Thermalization: history entries 1:nth.
    for j in 1:nth
        k = j
        dh[k], acc[k] = HMC!(U, Phi, dt, nsteps, lp, gp, sp, ymws, sws)
        pl[k], rho1[k], Lphi1[k], Lalp1[k], rho2[k], Lphi2[k], Lalp2[k] =
            measure_obs(U, Phi, sp, lp, gp, ymws)
        print_obs_line("THM", j, nth, sp.kap[1], acc[k], dh[k],
                       pl[k], rho1[k], Lphi1[k], Lalp1[k], rho2[k], Lphi2[k], Lalp2[k])
    end
    println(" ")

    # Measurement chain: history entries nth+1:nth+niter.
    for j in 1:niter
        k = nth + j
        dh[k], acc[k] = HMC!(U, Phi, dt, nsteps, lp, gp, sp, ymws, sws)
        pl[k], rho1[k], Lphi1[k], Lalp1[k], rho2[k], Lphi2[k], Lalp2[k] =
            measure_obs(U, Phi, sp, lp, gp, ymws)
        print_obs_line("MSM", j, niter, sp.kap[1], acc[k], dh[k],
                       pl[k], rho1[k], Lphi1[k], Lalp1[k], rho2[k], Lphi2[k], Lalp2[k])
    end
    accepted = count(view(acc, nth+1:nth+niter))

    # Write this k1 value's histories to BDIO. Uinfo tags: 3/4 plaquette, 5/6 rho,
    # 8/9 Lphi, 10/11 Lalp (field 1 then field 2), 12/13 dh, 14 acceptance rate.
    write_thm_msm!(fb, 3, 4, nth, pl)
    write_thm_msm!(fb, 5, 6, nth, rho1, rho2)
    write_thm_msm!(fb, 8, 9, nth, Lphi1, Lphi2)
    write_thm_msm!(fb, 10, 11, nth, Lalp1, Lalp2)
    write_thm_msm!(fb, 12, 13, nth, dh)

    BDIO_start_record!(fb, BDIO_BIN_F64LE, 14, true)
    BDIO_write!(fb, [accepted/niter])
    BDIO_write_hash!(fb)

    println("\n\n")
end

# Report TimerOutputs' default timer.
# NOTE(review): nothing visible in this script uses @timeit, so any timings shown
# presumably come from instrumentation inside LatticeGPU — confirm.
println("## Timming results")
print_timer(; linechars = :ascii)

# Close the BDIO output file opened during setup.
BDIO_close!(fb)

println("## END")
# END
# END
