Working HMC for SU(2) and SU(3)

This commit is contained in:
Alberto Ramos 2021-07-28 13:50:32 +02:00
parent f6be70070e
commit c378648508
6 changed files with 119 additions and 40 deletions

View file

@ -31,7 +31,7 @@ SU3() = SU3(1.0,0.0,0.0,0.0,1.0,0.0)
inverse(a::SU3) = SU3(conj(a.u11),conj(a.u21),(a.u12*a.u23 - a.u13*a.u22), inverse(a::SU3) = SU3(conj(a.u11),conj(a.u21),(a.u12*a.u23 - a.u13*a.u22),
conj(a.u12),conj(a.u22),(a.u13*a.u21 - a.u11*a.u23)) conj(a.u12),conj(a.u22),(a.u13*a.u21 - a.u11*a.u23))
dag(a::SU3) = inverse(a) dag(a::SU3) = inverse(a)
tr(g::SU3) = a.u11+a.u22+a.u11*conj(a.u22)-a.u12*conj(a.u21) tr(a::SU3) = a.u11+a.u22+conj(a.u11*a.u22 - a.u12*a.u21)
function Base.:*(a::SU3,b::SU3) function Base.:*(a::SU3,b::SU3)
@ -49,33 +49,33 @@ end
function Base.:/(a::SU3,b::SU3) function Base.:/(a::SU3,b::SU3)
bu31 = conj(b.u12*a.u23 - b.u13*b.u22) bu31 = (b.u12*b.u23 - b.u13*b.u22)
bu32 = conj(b.u13*b.u21 - b.u11*b.u23) bu32 = (b.u13*b.u21 - b.u11*b.u23)
bu33 = conj(b.u11*b.u22 - b.u12*b.u21) bu33 = (b.u11*b.u22 - b.u12*b.u21)
return SU3(a.u11*conj(b.u11) + a.u12*conj(b.u12) + a.u13*conj(b.u13), return SU3(a.u11*conj(b.u11) + a.u12*conj(b.u12) + a.u13*conj(b.u13),
a.u11*conj(b.u21) + a.u12*conj(b.u22) + a.u13*conj(b.u23), a.u11*conj(b.u21) + a.u12*conj(b.u22) + a.u13*conj(b.u23),
a.u11*conj(bu31) + a.u12*conj(bu32) + a.u13*conj(bu33), a.u11*(bu31) + a.u12*(bu32) + a.u13*(bu33),
a.u21*conj(b.u11) + a.u22*conj(b.u12) + a.u23*conj(b.u13), a.u21*conj(b.u11) + a.u22*conj(b.u12) + a.u23*conj(b.u13),
a.u21*conj(b.u21) + a.u22*conj(b.u22) + a.u23*conj(b.u23), a.u21*conj(b.u21) + a.u22*conj(b.u22) + a.u23*conj(b.u23),
a.u21*conj(bu31) + a.u22*conj(bu32) + a.u23*conj(bu33)) a.u21*(bu31) + a.u22*(bu32) + a.u23*(bu33))
end end
function Base.:\(a::SU3,b::SU3) function Base.:\(a::SU3,b::SU3)
au31 = conj(a.u12*a.u23 - a.u13*a.u22) au31 = (a.u12*a.u23 - a.u13*a.u22)
au32 = conj(a.u13*a.u21 - a.u11*a.u23) au32 = (a.u13*a.u21 - a.u11*a.u23)
bu31 = conj(b.u12*a.u23 - b.u13*b.u22) bu31 = conj(b.u12*b.u23 - b.u13*b.u22)
bu32 = conj(b.u13*b.u21 - b.u11*b.u23) bu32 = conj(b.u13*b.u21 - b.u11*b.u23)
bu33 = conj(b.u11*b.u22 - b.u12*b.u21) bu33 = conj(b.u11*b.u22 - b.u12*b.u21)
return SU3(conj(a.u11)*b.u11 + conj(a.u21)*b.u21 + conj(au31)*bu31, return SU3(conj(a.u11)*b.u11 + conj(a.u21)*b.u21 + (au31)*bu31,
conj(a.u11)*b.u12 + conj(a.u21)*b.u22 + conj(au31)*bu32, conj(a.u11)*b.u12 + conj(a.u21)*b.u22 + (au31)*bu32,
conj(a.u11)*b.u13 + conj(a.u21)*b.u23 + conj(au31)*bu33, conj(a.u11)*b.u13 + conj(a.u21)*b.u23 + (au31)*bu33,
conj(a.u12)*b.u11 + conj(a.u22)*b.u21 + conj(au32)*bu31, conj(a.u12)*b.u11 + conj(a.u22)*b.u21 + (au32)*bu31,
conj(a.u12)*b.u12 + conj(a.u22)*b.u22 + conj(au32)*bu32, conj(a.u12)*b.u12 + conj(a.u22)*b.u22 + (au32)*bu32,
conj(a.u12)*b.u13 + conj(a.u22)*b.u23 + conj(au32)*bu33) conj(a.u12)*b.u13 + conj(a.u22)*b.u23 + (au32)*bu33)
end end
@ -97,9 +97,7 @@ function projalg(a::SU3)
sr3ov2::Float64 = 0.866025403784438646763723170752 sr3ov2::Float64 = 0.866025403784438646763723170752
au33 = conj(a.u11*a.u22 - a.u12*a.u21) ditr = ( imag(a.u11) + imag(a.u22) + 2.0*imag(a.u11*a.u22 - a.u12*a.u21) )/3.0
ditr = ( imag(a.u11) + imag(a.u22) + imag(au33) )/3.0
m12 = (a.u12 - conj(a.u21))/2.0 m12 = (a.u12 - conj(a.u21))/2.0
m13 = (a.u13 - (a.u12*a.u23 - a.u13*a.u22) )/2.0 m13 = (a.u13 - (a.u12*a.u23 - a.u13*a.u22) )/2.0
m23 = (a.u23 - (a.u13*a.u21 - a.u11*a.u23) )/2.0 m23 = (a.u23 - (a.u13*a.u21 - a.u11*a.u23) )/2.0
@ -107,8 +105,9 @@ function projalg(a::SU3)
return SU3alg(imag( m12 ), imag( m13 ), imag( m23 ), return SU3alg(imag( m12 ), imag( m13 ), imag( m23 ),
real( m12 ), real( m13 ), real( m23 ), real( m12 ), real( m13 ), real( m23 ),
(imag(a.u11)-imag(a.u22))/2.0, (imag(a.u11)-imag(a.u22))/2.0,
-sr3ov2*(imag(au33)-ditr)) sr3ov2*(ditr))
end end
dot(a::SU3alg,b::SU3alg) = a.t1*b.t1 + a.t2*b.t2 + a.t3*b.t3 + a.t4*b.t4 + dot(a::SU3alg,b::SU3alg) = a.t1*b.t1 + a.t2*b.t2 + a.t3*b.t3 + a.t4*b.t4 +
a.t5*b.t5 + a.t6*b.t6 + a.t7*b.t7 + a.t8*b.t8 a.t5*b.t5 + a.t6*b.t6 + a.t7*b.t7 + a.t8*b.t8
norm2(a::SU3alg) = a.t1^2 + a.t2^2 + a.t3^2 + a.t4^2 + a.t5^2 + a.t6^2 + a.t7^2 + a.t8^2 norm2(a::SU3alg) = a.t1^2 + a.t2^2 + a.t3^2 + a.t4^2 + a.t5^2 + a.t6^2 + a.t7^2 + a.t8^2
@ -184,19 +183,19 @@ end
function Base.:/(a::M3x3,b::SU3) function Base.:/(a::M3x3,b::SU3)
bu31 = conj(b.u12*b.u23 - b.u13*b.u22) bu31 = (b.u12*b.u23 - b.u13*b.u22)
bu32 = conj(b.u13*b.u21 - b.u11*b.u23) bu32 = (b.u13*b.u21 - b.u11*b.u23)
bu33 = conj(b.u11*b.u22 - b.u12*b.u21) bu33 = (b.u11*b.u22 - b.u12*b.u21)
return M3x3(a.u11*conj(b.u11) + a.u12*conj(b.u12) + a.u13*conj(b.u13), return M3x3(a.u11*conj(b.u11) + a.u12*conj(b.u12) + a.u13*conj(b.u13),
a.u11*conj(b.u21) + a.u12*conj(b.u22) + a.u13*conj(b.u23), a.u11*conj(b.u21) + a.u12*conj(b.u22) + a.u13*conj(b.u23),
a.u11*conj(bu31) + a.u12*conj(bu32) + a.u13*conj(bu33), a.u11*(bu31) + a.u12*(bu32) + a.u13*(bu33),
a.u21*conj(b.u11) + a.u22*conj(b.u12) + a.u23*conj(b.u13), a.u21*conj(b.u11) + a.u22*conj(b.u12) + a.u23*conj(b.u13),
a.u21*conj(b.u21) + a.u22*conj(b.u22) + a.u23*conj(b.u23), a.u21*conj(b.u21) + a.u22*conj(b.u22) + a.u23*conj(b.u23),
a.u21*conj(bu31) + a.u22*conj(bu32) + a.u23*conj(bu33), a.u21*(bu31) + a.u22*(bu32) + a.u23*(bu33),
a.u31*conj(b.u11) + a.u32*conj(b.u12) + a.u33*conj(b.u13), a.u31*conj(b.u11) + a.u32*conj(b.u12) + a.u33*conj(b.u13),
a.u31*conj(b.u21) + a.u32*conj(b.u22) + a.u33*conj(b.u23), a.u31*conj(b.u21) + a.u32*conj(b.u22) + a.u33*conj(b.u23),
a.u31*conj(bu31) + a.u32*conj(bu32) + a.u33*conj(bu33)) a.u31*(bu31) + a.u32*(bu32) + a.u33*(bu33))
end end
Base.:*(a::Number,b::M3x3) = M3x3(a*b.u11, a*b.u12, a*bu13, Base.:*(a::Number,b::M3x3) = M3x3(a*b.u11, a*b.u12, a*bu13,

View file

@ -126,8 +126,22 @@ end
return s return s
end end
@inline map2latt(th::NTuple{3,Int64},bl::NTuple{3,Int64}) = CartesianIndex(th[1],bl[1],bl[2],bl[3]) @inline map2latt(th::NTuple{3,Int64},bl::NTuple{3,Int64}) = CartesianIndex(th[1],bl[3],bl[1],bl[2])
# Translate a (thread, block) index pair into a 4-dimensional lattice site,
# using the block shape `lp.lblock` and the lattice extent `lp.iL`.
#
# - `th[1]` packs the in-block position along directions 1 and 2: it is split
#   with `lp.lblock[1]` as the stride.
# - `bl[1]` packs the block position along directions 1 and 2: it is split
#   with `div(lp.iL[1], lp.lblock[1])` as the stride.
# - `th[2]`/`bl[2]` and `th[3]`/`bl[3]` address directions 3 and 4 directly.
#
# NOTE(review): the direction-1 block offset is taken as `mod(bl[1], n)` while
# the direction-2 offset uses `div(bl[1]-1, n)`; this is kept exactly as is.
@inline function map2latt(th::NTuple{3,Int64},bl::NTuple{3,Int64}, lp)
    tile = lp.lblock
    ntiles1 = div(lp.iL[1], tile[1])

    # Position inside the block along the first two directions.
    in1 = mod1(th[1], tile[1])
    in2 = div(th[1] - 1, tile[1]) + 1

    # Offset contributed by the block along the first two directions.
    off1 = mod(bl[1], ntiles1) * tile[1]
    off2 = div(bl[1] - 1, ntiles1) * tile[2]

    return CartesianIndex(in1 + off1,
                          in2 + off2,
                          (bl[2] - 1) * tile[3] + th[2],
                          (bl[3] - 1) * tile[4] + th[3])
end
export map2latt, up, dw, shift export map2latt, up, dw, shift
end end

View file

@ -43,6 +43,16 @@ struct YMworkspace
rs = zeros(Float64, lp.iL...) rs = zeros(Float64, lp.iL...)
return new(f1, f2, mm, u1, replace_storage(CuArray, cs)) return new(f1, f2, mm, u1, replace_storage(CuArray, cs))
end end
if (T == SU3)
f1 = field(SU3alg, lp)
f2 = field(SU3alg, lp)
mm = field(SU3alg, lp)
u1 = field(SU3, lp)
cs = zeros(ComplexF64,lp.iL...)
rs = zeros(Float64, lp.iL...)
return new(f1, f2, mm, u1, replace_storage(CuArray, cs))
end
return nothing return nothing
end end
end end

View file

@ -17,7 +17,7 @@ function krnl_plaq!(plx, U, ipl, lp::SpaceParm)
Xu1 = up(X, id1, lp) Xu1 = up(X, id1, lp)
Xu2 = up(X, id2, lp) Xu2 = up(X, id2, lp)
plx[X] = tr(U[X, id1]*U[Xu1, id2] / (U[X, id2]*U[Xu2, id1])) @inbounds plx[X] = tr(U[X, id1]*U[Xu1, id2] / (U[X, id2]*U[Xu2, id1]))
return nothing return nothing
end end
@ -27,8 +27,8 @@ function krnl_plaq!(plx, U, lp::SpaceParm)
X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z), X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z),
(CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z)) (CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z))
plx[X] = complex(0.0) @inbounds plx[X] = complex(0.0)
for ipl in 1:lp.npls @inbounds for ipl in 1:lp.npls
id1, id2 = lp.plidx[ipl] id1, id2 = lp.plidx[ipl]
Xu1 = up(X, id1, lp) Xu1 = up(X, id1, lp)
Xu2 = up(X, id2, lp) Xu2 = up(X, id2, lp)
@ -55,11 +55,13 @@ function krnl_force_wilson_pln!(frc1, frc2, U, ipl, lp::SpaceParm, gp::GaugeParm
F2 = projalg(a*b) F2 = projalg(a*b)
F3 = projalg(b*a) F3 = projalg(b*a)
frc1[X ,id1] -= F1 @inbounds begin
frc1[X ,id2] += F1 frc1[X ,id1] -= F1
frc2[Xu1,id2] -= F2 frc1[X ,id2] += F1
frc2[Xu2,id1] += F3 frc2[Xu1,id2] -= F2
frc2[Xu2,id1] += F3
end
return nothing return nothing
end end

View file

@ -21,6 +21,10 @@ function field(::Type{T}, lp::SpaceParm) where {T <: Union{Group, Algebra}}
zeros(Float64, sz))) zeros(Float64, sz)))
elseif (T == SU3) elseif (T == SU3)
As = StructArray{SU3}((ones(ComplexF64, sz), zeros(ComplexF64, sz), zeros(ComplexF64, sz), zeros(ComplexF64, sz), ones(ComplexF64, sz), zeros(ComplexF64, sz))) As = StructArray{SU3}((ones(ComplexF64, sz), zeros(ComplexF64, sz), zeros(ComplexF64, sz), zeros(ComplexF64, sz), ones(ComplexF64, sz), zeros(ComplexF64, sz)))
# As = Array{SU3, lp.ndim+1}(undef, sz)
# CUDA.@sync begin
# CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_SU3_zero!(As, lp)
# end
elseif (T == SU3alg) elseif (T == SU3alg)
As = StructArray{SU3alg}((zeros(Float64, sz), As = StructArray{SU3alg}((zeros(Float64, sz),
zeros(Float64, sz), zeros(Float64, sz),
@ -30,6 +34,11 @@ function field(::Type{T}, lp::SpaceParm) where {T <: Union{Group, Algebra}}
zeros(Float64, sz), zeros(Float64, sz),
zeros(Float64, sz), zeros(Float64, sz),
zeros(Float64, sz))) zeros(Float64, sz)))
# As = Array{SU3alg, lp.ndim+1}(undef, sz)
# CUDA.@sync begin
# CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_SU3alg_zero!(As, lp)
# end
end end
return replace_storage(CuArray, As) return replace_storage(CuArray, As)
@ -71,6 +80,9 @@ function zero!(X)
fill!(LazyRows(X).t6, 0.0) fill!(LazyRows(X).t6, 0.0)
fill!(LazyRows(X).t7, 0.0) fill!(LazyRows(X).t7, 0.0)
fill!(LazyRows(X).t8, 0.0) fill!(LazyRows(X).t8, 0.0)
# CUDA.@sync begin
# CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_SU3alg_zero!(X, lp)
# end
end end
if (eltype(X) == SU2) if (eltype(X) == SU2)
@ -85,6 +97,9 @@ function zero!(X)
fill!(LazyRows(X).u21, complex(0.0)) fill!(LazyRows(X).u21, complex(0.0))
fill!(LazyRows(X).u22, complex(1.0)) fill!(LazyRows(X).u22, complex(1.0))
fill!(LazyRows(X).u23, complex(0.0)) fill!(LazyRows(X).u23, complex(0.0))
# CUDA.@sync begin
# CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_SU3_zero!(X, lp)
# end
end end
return nothing return nothing
@ -97,7 +112,7 @@ function norm2(X)
d = CUDA.mapreduce(x->x^2, +, LazyRows(X).t1) + d = CUDA.mapreduce(x->x^2, +, LazyRows(X).t1) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t2) + CUDA.mapreduce(x->x^2, +, LazyRows(X).t2) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t3) CUDA.mapreduce(x->x^2, +, LazyRows(X).t3)
elseif (eltype(X) == SU2alg) elseif (eltype(X) == SU3alg)
d = CUDA.mapreduce(x->x^2, +, LazyRows(X).t1) + d = CUDA.mapreduce(x->x^2, +, LazyRows(X).t1) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t2) + CUDA.mapreduce(x->x^2, +, LazyRows(X).t2) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t3) + CUDA.mapreduce(x->x^2, +, LazyRows(X).t3) +
@ -106,7 +121,42 @@ function norm2(X)
CUDA.mapreduce(x->x^2, +, LazyRows(X).t6) + CUDA.mapreduce(x->x^2, +, LazyRows(X).t6) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t7) + CUDA.mapreduce(x->x^2, +, LazyRows(X).t7) +
CUDA.mapreduce(x->x^2, +, LazyRows(X).t8) CUDA.mapreduce(x->x^2, +, LazyRows(X).t8)
# d = CUDA.mapreduce(norm2, +, X)
end end
return d return d
end end
# CUDA kernel: overwrite every element `G[X,id]`, `id in 1:lp.ndim`, at this
# thread's lattice site with the SU3 element whose stored entries are
# u11 = u22 = 1 and u12 = u13 = u21 = u23 = 0.
#
# The site `X` is derived from the CUDA thread and block indices through the
# two-argument `map2latt`.
#
# Each element is assigned as a whole. Writing `G[X,id].u11 = ...` only targets
# the temporary value returned by `getindex` (and is rejected outright for an
# immutable struct), so it never updates `G` itself.
#
# NOTE(review): assumes the 6-argument `SU3` constructor takes its entries in
# the order u11, u12, u13, u21, u22, u23, as the `StructArray{SU3}` set-up in
# `field` suggests — confirm against the type definition.
function krnl_SU3_zero!(G, lp::SpaceParm)
    X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z),
                 (CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z))

    # Build the value once; it is the same for every direction.
    identity_link = SU3(complex(1.0), complex(0.0), complex(0.0),
                        complex(0.0), complex(1.0), complex(0.0))
    for id in 1:lp.ndim
        G[X,id] = identity_link
    end
    return nothing
end
# CUDA kernel: overwrite every element `G[X,id]`, `id in 1:lp.ndim`, at this
# thread's lattice site with an `SU3alg` whose eight components t1..t8 are 0.0.
#
# The site `X` is derived from the CUDA thread and block indices through the
# two-argument `map2latt`.
#
# Each element is assigned as a whole. Writing `G[X,id].t1 = ...` only targets
# the temporary value returned by `getindex` (and is rejected outright for an
# immutable struct), so it never updates `G` itself.
#
# NOTE(review): assumes `SU3alg` has an 8-argument constructor taking t1..t8
# in order — confirm against the type definition.
function krnl_SU3alg_zero!(G, lp::SpaceParm)
    X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z),
                 (CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z))

    # Build the value once; it is the same for every direction.
    zero_alg = SU3alg(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    for id in 1:lp.ndim
        G[X,id] = zero_alg
    end
    return nothing
end

View file

@ -15,7 +15,7 @@ function gauge_action(U, lp::SpaceParm, gp::GaugeParm, kp::KernelParm, ymws::YMw
CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_plaq!(ymws.cm, U, lp) CUDA.@cuda threads=kp.threads blocks=kp.blocks krnl_plaq!(ymws.cm, U, lp)
end end
S = gp.beta*( prod(lp.iL)*lp.npls - S = gp.beta*( prod(lp.iL)*lp.npls -
CUDA.mapreduce(real, +, real.(ymws.cm))/gp.ng ) CUDA.mapreduce(real, +, ymws.cm)/gp.ng )
return S return S
end end
@ -29,8 +29,12 @@ function plaquette(U, lp::SpaceParm, gp::GaugeParm, kp::KernelParm, ymws::YMwork
return CUDA.mapreduce(real, +, real.(ymws.cm))/(prod(lp.iL)*lp.npls) return CUDA.mapreduce(real, +, real.(ymws.cm))/(prod(lp.iL)*lp.npls)
end end
hamiltonian(mom, U, lp, gp, kp, ymws) = norm2(mom)/2.0 + function hamiltonian(mom, U, lp, gp, kp, ymws)
gauge_action(U, lp, gp, kp, ymws) K = norm2(mom)/2.0
V = gauge_action(U, lp, gp, kp, ymws)
println("K: ", K, " V: ", V)
return K+V
end
function HMC!(U, eps, ns, lp::SpaceParm, gp::GaugeParm, kp::KernelParm, ymws::YMworkspace; noacc=false) function HMC!(U, eps, ns, lp::SpaceParm, gp::GaugeParm, kp::KernelParm, ymws::YMworkspace; noacc=false)
@ -65,9 +69,9 @@ function krnl_updt!(mom, frc1, frc2, eps1, U, eps2, lp::SpaceParm)
X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z), X = map2latt((CUDA.threadIdx().x,CUDA.threadIdx().y,CUDA.threadIdx().z),
(CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z)) (CUDA.blockIdx().x,CUDA.blockIdx().y,CUDA.blockIdx().z))
for id in 1:lp.ndim @inbounds for id in 1:lp.ndim
mom[X,id] = mom[X,id] + eps1 * (frc1[X,id]+frc2[X,id]) mom[X,id] = mom[X,id] + eps1 * (frc1[X,id]+frc2[X,id])
U[X,id] = expm(U[X,id], mom[X,id], eps2) U[X,id] = expm(U[X,id],mom[X,id], eps2)
end end
return nothing return nothing