# Patch summary. This is explanatory text placed ahead of the first `diff --git`
# header; it is not part of any hunk. The patch below is stored with its original
# line breaks and indentation collapsed into single spaces.
#
# Files touched:
#   src/LatticeGPU.jl, src/YM/YM.jl, src/YM/YMupdate.jl, test/update/test1.jl
#
# src/LatticeGPU.jl and src/YM/YM.jl
#   * `evenodd` is added to the export line that already exports `updt_or_wilson!`.
#
# src/YM/YMupdate.jl
#   * Adds `struct EvenOdd{T}` with a `state::Bool` flag and two fields `e`, `o` of type T.
#   * Adds `evenodd(lp::SpaceParm)`. When every entry of `lp.iL` is even it loops over
#     `b in 1:lp.bsz`, calls `point_color((b, 1), lp)` and stores `b` in `e` when the
#     colour is even, in `o` otherwise. Both host arrays have length `div(lp.bsz,2)` and
#     are wrapped in `CuArray` before being returned as `EvenOdd(true, ...)`.
#     Otherwise it returns `EvenOdd(false, nothing, nothing)`.
#   * Adds `evenodd(eo::EvenOdd, lp::SpaceParm)`: returns `eo` when `eo.state` is true,
#     else returns `evenodd(lp)`.
#   * Adds `updt_or_wilson!(U, gp, eo::EvenOdd, lp)`. With `eo.state == true` it launches
#     `krnl_or_wilson!` with `threads=div(lp.bsz,2)` and `blocks=lp.rsz`, once per
#     direction `id in 1:N` with `eo.e` and then once per direction with `eo.o`, each
#     launch wrapped in `CUDA.@sync`. With `eo.state == false` it calls
#     `updt_or_wilson!(U, gp, lp)`. It returns `nothing`.
#   * The existing `updt_or_wilson!(U, gp, lp)` and the colour-based `krnl_or_wilson!` now
#     pass `gp` on to `sum_staples`; the kernel's `icl` argument is annotated `::Int64`.
#   * Adds a second `krnl_or_wilson!` method taking `eo::AbstractArray`: the block-local
#     index is read as `eo[threadIdx().x]` and the link `U[b,id,r]` is replaced by
#     `do_or(U[b,id,r], stp)` without any colour test.
#   * `sum_staples` for `BC_PERIODIC` is removed from its old position and re-added near
#     the end of the file with an extra `gp::GaugeParm` argument, which its body does not
#     reference.
#   * Adds `sum_staples` methods for `BC_SF_ORBI` and `BC_OPEN`. Neither body references
#     `ztw`. When `I[end] == lp.iL[end]` the forward staple `m1` is multiplied by
#     `gp.cG[1]` and uses an entry of `gp.Ubnd` in place of one link of `U`, or is zero.
#     When `I[end] == 1` the backward staple `m2` is multiplied by `gp.cG[2]` if
#     `in==N` or `im==N`, and is zero otherwise.
#
# test/update/test1.jl
#   * Loads `TimerOutputs`, builds a (16,16,16,16) lattice with (4,4,4,4) blocks and
#     `BC_PERIODIC`, and constructs `GaugeParm` with three additional arguments.
#   * Copies `U`, runs 2000 calls of the `EvenOdd` update and prints the observables;
#     restores `U .= Ucp`, runs 2000 calls of the original update, prints again, and
#     finally calls `print_timer(linechars = :ascii)`.
#
# NOTE(review): `evenodd(lp)` only checks that `lp.iL` is even, yet the index tables are
#   built from block r = 1 alone and sized `div(lp.bsz,2)`. This presumably also requires
#   every block to contain the same even/odd pattern as block 1 (e.g. even block
#   dimensions) -- TODO confirm against `point_color` and the `SpaceParm` constructor.
# NOTE(review): the type parameter of the returned `EvenOdd` depends on the run-time
#   value of `lp.iL` (device arrays vs `nothing`) -- confirm this is acceptable to callers.
# NOTE(review): the `BC_OPEN` method reads `gp.Ubnd[im]` at `I[end] == lp.iL[end]` just
#   like the `BC_SF_ORBI` method, but has no `elseif (im==N)` branch -- verify that this
#   is the intended open-boundary staple.
# NOTE(review): no `@test` assertion appears in the visible test hunks; the script only
#   prints observables and timings -- confirm whether a comparison was intended.
#
diff --git a/src/LatticeGPU.jl b/src/LatticeGPU.jl index 3ce860a..29a1b6c 100644 --- a/src/LatticeGPU.jl +++ b/src/LatticeGPU.jl @@ -47,7 +47,7 @@ export FlowIntr, wfl_euler, zfl_euler, wfl_rk2, zfl_rk2, wfl_rk3, zfl_rk3 export flw, flw_adapt export sfcoupling, bndfield, setbndfield export import_lex64, import_cern64, import_bsfqcd, save_cnfg, read_cnfg, read_gp -export updt_or_wilson! +export evenodd, updt_or_wilson! include("Spinors/Spinors.jl") diff --git a/src/YM/YM.jl b/src/YM/YM.jl index 079d941..5638182 100644 --- a/src/YM/YM.jl +++ b/src/YM/YM.jl @@ -182,6 +182,6 @@ include("YMio.jl") export import_lex64, import_cern64, import_bsfqcd, save_cnfg, read_cnfg, read_gp include("YMupdate.jl") -export updt_or_wilson! +export evenodd, updt_or_wilson! end diff --git a/src/YM/YMupdate.jl b/src/YM/YMupdate.jl index 0fd3a7b..03d5bb6 100644 --- a/src/YM/YMupdate.jl +++ b/src/YM/YMupdate.jl @@ -9,19 +9,83 @@ ### created: Mon Sep 1 08:59:22 2025 ### +struct EvenOdd{T} + state::Bool + e::T + o::T +end + +function evenodd(eo::EvenOdd, lp::SpaceParm) + + if eo.state + return eo + end + + return evenodd(lp) +end + +function evenodd(lp::SpaceParm) + + if all(lp.iL .% 2 .== 0) + e = Array{Int64,1}(undef, div(lp.bsz,2)) + o = Array{Int64,1}(undef, div(lp.bsz,2)) + + ie = 1 + io = 1 + for b in 1:lp.bsz + clr = point_color((b, 1), lp) + if clr % 2 == 0 + e[ie] = b + ie = ie + 1 + else + o[io] = b + io = io + 1 + end + end + return EvenOdd(true, CuArray(e), CuArray(o)) + else + return EvenOdd(false,nothing, nothing) + end + +end + +function updt_or_wilson!(U, gp::GaugeParm, eo::EvenOdd, lp::SpaceParm{N,M,B,D}) where {N,M,B,D} + + if eo.state + @timeit "OR update (Wilson action [eo])" begin + ztw = ztwist(gp, lp) + for id in 1:N + CUDA.@sync begin + CUDA.@cuda threads=div(lp.bsz,2) blocks=lp.rsz krnl_or_wilson!(U, id, eo.e, ztw, gp, lp) + end + end + + for id in 1:N + CUDA.@sync begin + CUDA.@cuda threads=div(lp.bsz,2) blocks=lp.rsz krnl_or_wilson!(U, id, eo.o, ztw, gp, lp) + end + 
end + end + else + updt_or_wilson!(U, gp, lp) + end + + return nothing +end + function updt_or_wilson!(U, gp::GaugeParm, lp::SpaceParm{N,M,B,D}) where {N,M,B,D} @timeit "OR update (Wilson action)" begin ztw = ztwist(gp, lp) for id in 1:N CUDA.@sync begin - CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl_or_wilson!(U, id, 0, ztw, lp) + CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl_or_wilson!(U, id, 0, ztw, gp, lp) end end for id in 1:N CUDA.@sync begin - CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl_or_wilson!(U, id, 1, ztw, lp) + CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl_or_wilson!(U, id, 1, ztw, gp, lp) end end end @@ -30,19 +94,28 @@ function updt_or_wilson!(U, gp::GaugeParm, lp::SpaceParm{N,M,B,D}) where {N,M,B, end -function krnl_or_wilson!(U::AbstractArray{SU3{T}}, id, icl, ztw, lp::SpaceParm) where {T} +function krnl_or_wilson!(U::AbstractArray{SU3{T}}, id, icl::Int64, ztw, gp::GaugeParm, lp::SpaceParm) where {T} b = Int64(CUDA.threadIdx().x) r = Int64(CUDA.blockIdx().x) clr = point_color((b, r), lp) if clr % 2 == icl - stp = sum_staples(U, id, (b, r), ztw, lp) + stp = sum_staples(U, id, (b, r), ztw, gp, lp) U[b,id,r] = do_or(U[b,id,r], stp) end return nothing end +function krnl_or_wilson!(U::AbstractArray{SU3{T}}, id, eo::AbstractArray, ztw, gp::GaugeParm, lp::SpaceParm) where {T} + + b = eo[Int64(CUDA.threadIdx().x)] + r = Int64(CUDA.blockIdx().x) + stp = sum_staples(U, id, (b, r), ztw, gp, lp) + U[b,id,r] = do_or(U[b,id,r], stp) + + return nothing +end @inline function overrelax_cd(w::Tuple{T,T}) where {T} @@ -56,44 +129,6 @@ end end -@inline function sum_staples(U::AbstractArray{SU3{T}},im,pt::NTuple{2,Int64}, ztw, lp::SpaceParm{N,M,BC_PERIODIC,D}) where {T,N,M,D} - - I = point_coord(pt, lp) - bum, rum = up(pt, im, lp) - stp = zero(M3x3{T}) - for in in 1:N - if (in != im) - zf = one(ztw[begin]) - for i in 1:length(lp.plidx) - if (lp.plidx[i] == (in, im)) || (lp.plidx[i] == (im, in)) - if im > in - zf = ztw[i] - else - zf = conj(ztw[i]) - end - break - end - 
end - - bun, run, bdn, rdn = updw(pt, in, lp) - bd, rd = up((bdn, rdn), im, lp) - m1 = convert(M3x3{T}, U[bum,in,rum]/U[bun,im,run]/U[pt[1],in,pt[2]]) - if (I[im]==1) && (I[in]==1) - m1 = m1 * zf - end - m2 = convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) - Id = point_coord((bdn, rdn), lp) - if (Id[im]==1) && (Id[in]==1) - m2 = m2 * conj(zf) - end - - stp = stp + m1 + m2 - end - end - - return stp -end - @inline function do_or(U::SU3{T}, stp) where {T} M = do_or12(U, stp) @@ -172,4 +207,118 @@ end return SU3{T}(u11, u12, u13, u21, u22, u23) end +@inline function sum_staples(U::AbstractArray{SU3{T}},im,pt::NTuple{2,Int64}, ztw, gp::GaugeParm, lp::SpaceParm{N,M,BC_PERIODIC,D}) where {T,N,M,D} + + I = point_coord(pt, lp) + bum, rum = up(pt, im, lp) + stp = zero(M3x3{T}) + for in in 1:N + if (in != im) + zf = one(ztw[begin]) + for i in 1:length(lp.plidx) + if (lp.plidx[i] == (in, im)) || (lp.plidx[i] == (im, in)) + if im > in + zf = ztw[i] + else + zf = conj(ztw[i]) + end + break + end + end + + bun, run, bdn, rdn = updw(pt, in, lp) + bd, rd = up((bdn, rdn), im, lp) + m1 = convert(M3x3{T}, U[bum,in,rum]/U[bun,im,run]/U[pt[1],in,pt[2]]) + if (I[im]==1) && (I[in]==1) + m1 = m1 * zf + end + m2 = convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) + Id = point_coord((bdn, rdn), lp) + if (Id[im]==1) && (Id[in]==1) + m2 = m2 * conj(zf) + end + + stp = stp + m1 + m2 + end + end + + return stp +end + +@inline function sum_staples(U::AbstractArray{SU3{T}},im,pt::NTuple{2,Int64}, ztw, gp::GaugeParm, lp::SpaceParm{N,M,BC_SF_ORBI,D}) where {T,N,M,D} + + I = point_coord(pt, lp) + bum, rum = up(pt, im, lp) + stp = zero(M3x3{T}) + for in in 1:N + if (in != im) + bun, run, bdn, rdn = updw(pt, in, lp) + bd, rd = up((bdn, rdn), im, lp) + + if (I[end] == lp.iL[end]) + if (in==N) + m1 = gp.cG[1]*convert(M3x3{T}, U[bum,in,rum]/gp.Ubnd[im]/U[pt[1],in,pt[2]]) + elseif (im==N) + m1 = gp.cG[1]*convert(M3x3{T}, gp.Ubnd[in]/U[bun,im,run]/U[pt[1],in,pt[2]]) + else + 
m1 = zero(M3x3{T}) + end + else + m1 = convert(M3x3{T}, U[bum,in,rum]/U[bun,im,run]/U[pt[1],in,pt[2]]) + end + + if I[end] == 1 + if (in==N) || (im==N) + m2 = gp.cG[2]*convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) + else + m2 = zero(M3x3{T}) + end + else + m2 = convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) + end + stp = stp + m1 + m2 + + end + end + + return stp +end + +@inline function sum_staples(U::AbstractArray{SU3{T}},im,pt::NTuple{2,Int64}, ztw, gp::GaugeParm, lp::SpaceParm{N,M,BC_OPEN,D}) where {T,N,M,D} + + I = point_coord(pt, lp) + bum, rum = up(pt, im, lp) + stp = zero(M3x3{T}) + for in in 1:N + if (in != im) + bun, run, bdn, rdn = updw(pt, in, lp) + bd, rd = up((bdn, rdn), im, lp) + + if (I[end] == lp.iL[end]) + if (in==N) + m1 = gp.cG[1]*convert(M3x3{T}, U[bum,in,rum]/gp.Ubnd[im]/U[pt[1],in,pt[2]]) + else + m1 = zero(M3x3{T}) + end + else + m1 = convert(M3x3{T}, U[bum,in,rum]/U[bun,im,run]/U[pt[1],in,pt[2]]) + end + + if I[end] == 1 + if (in==N) || (im==N) + m2 = gp.cG[2]*convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) + else + m2 = zero(M3x3{T}) + end + else + m2 = convert(M3x3{T},dag(U[bdn,in,rdn]\U[bdn,im,rdn]*U[bd, in, rd])) + end + stp = stp + m1 + m2 + + end + end + + return stp +end + diff --git a/test/update/test1.jl b/test/update/test1.jl index 1a6d4ce..5679a65 100644 --- a/test/update/test1.jl +++ b/test/update/test1.jl @@ -9,26 +9,49 @@ ### created: Tue Sep 2 16:29:48 2025 ### -using LatticeGPU, Test, CUDA +using LatticeGPU, Test, CUDA, TimerOutputs T = Float64 -lp = SpaceParm{4}((4,4,4,4), (4,4,4,4), BC_PERIODIC, (0,0,0,0,0,1)) -gp = GaugeParm{T}(SU3{T}, 6.5, 1.0) +#lp = SpaceParm{4}((4,4,4,4), (4,4,4,4), BC_SF_ORBI, (0,0,0,0,0,0)) +lp = SpaceParm{4}((16,16,16,16), (4,4,4,4), BC_PERIODIC, (0,0,0,0,0,1)) +gp = GaugeParm{T}(SU3{T}, 6.5, 1.0, (1.5,1.5), (0.0,0.0), lp.iL[1]) ymws = YMworkspace(SU3, T, lp) +println(gp.Ubnd) randomize!(ymws.mom, lp, ymws) U = exp.(ymws.mom) act = gauge_action(U, 
lp, gp, ymws) plq = plaquette(U, lp, gp, ymws) pl_exact = Eoft_plaq(U, gp, lp, ymws) cl_exact = Eoft_clover(U, gp, lp, ymws) +println(lp) println("## Random config: ") println(" - act: ", act) println(" - plq: ", plq) println(" - Epl: ", pl_exact) println(" - Ecl: ", cl_exact) -updt_or_wilson!(U, gp, lp) +Ucp = copy(U) + +eo = evenodd(lp) +for i in 1:2000 + updt_or_wilson!(U, gp, eo, lp) +end + +plq = plaquette(U, lp, gp, ymws) +act = gauge_action(U, lp, gp, ymws) +pl_exact = Eoft_plaq(U, gp, lp, ymws) +cl_exact = Eoft_clover(U, gp, lp, ymws) +println("## After OR [eo]: ") +println(" - act: ", act) +println(" - plq: ", plq) +println(" - Epl: ", pl_exact) +println(" - Ecl: ", cl_exact) + +U .= Ucp +for i in 1:2000 + updt_or_wilson!(U, gp, lp) +end plq = plaquette(U, lp, gp, ymws) act = gauge_action(U, lp, gp, ymws) @@ -39,3 +62,5 @@ println(" - act: ", act) println(" - plq: ", plq) println(" - Epl: ", pl_exact) println(" - Ecl: ", cl_exact) + +print_timer(linechars = :ascii)