Reduction arrays have as argument CartesianIndex

This commit is contained in:
Alberto Ramos 2021-10-17 22:48:30 +02:00
parent 79a6e21cbc
commit 428f9782ce
6 changed files with 72 additions and 21 deletions

View file

@ -19,7 +19,9 @@ vector_field(::Type{T}, lp::SpaceParm) where {T} = CuArray{T, 3}(undef, lp.b
scalar_field(::Type{T}, lp::SpaceParm) where {T} = CuArray{T, 2}(undef, lp.bsz, lp.rsz) scalar_field(::Type{T}, lp::SpaceParm) where {T} = CuArray{T, 2}(undef, lp.bsz, lp.rsz)
nscalar_field(::Type{T}, n, lp::SpaceParm) where {T} = CuArray{T, 3}(undef, lp.bsz, n, lp.rsz) nscalar_field(::Type{T}, n, lp::SpaceParm) where {T} = CuArray{T, 3}(undef, lp.bsz, n, lp.rsz)
export vector_field, scalar_field, nscalar_field scalar_field_point(::Type{T}, lp::SpaceParm{N,M,D}) where {T,N,M,D} = CuArray{T, N}(undef, lp.iL...)
export vector_field, scalar_field, nscalar_field, scalar_field_point
end end

View file

@ -27,7 +27,7 @@ export up, dw, updw, global_point
include("Fields/Fields.jl") include("Fields/Fields.jl")
using .Fields using .Fields
export vector_field, scalar_field, nscalar_field export vector_field, scalar_field, nscalar_field, scalar_field_point
include("MD/MD.jl") include("MD/MD.jl")
using .MD using .MD

View file

@ -169,25 +169,67 @@ Given a point `x` with index `p`, this routine returns the index of the points
return bu, ru, bd, rd return bu, ru, bd, rd
end end
@inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm) @inline cntb(nb, id::Int64, lp::SpaceParm) = mod(div(nb-1,lp.blkS[id]),lp.blk[id])
@inline cntr(nr, id::Int64, lp::SpaceParm) = mod(div(nr-1,lp.rbkS[id]),lp.rbk[id])
@inline cnt(nb, nr, id::Int64, lp::SpaceParm) = 1 + cntb(nb,id,lp) + cntr(nr,id,lp)*lp.blk[id]
@inline cntb(nb, id::Int64, lp::SpaceParm) = mod(div(nb-1,lp.blkS[id]),lp.blk[id]) @inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm{2,M,D}) where {M,D}
@inline cntr(nr, id::Int64, lp::SpaceParm) = mod(div(nr-1,lp.rbkS[id]),lp.rbk[id])
@inline cnt(nb, nr, id::Int64, lp::SpaceParm) = 1 + cntb(nb,id,lp) + cntr(nr,id,lp)*lp.blk[id] i1 = cnt(p[1], p[2], 1, lp)
i2 = cnt(p[1], p[2], 2, lp)
pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return pt # pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return CartesianIndex{4}(i1,i2)
end end
@inline function point_time(p::NTuple{2,Int64}, lp::SpaceParm) @inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm{3,M,D}) where {M,D}
@inline cntb(nb, id::Int64, lp::SpaceParm) = mod(div(nb-1,lp.blkS[id]),lp.blk[id]) i1 = cnt(p[1], p[2], 1, lp)
@inline cntr(nr, id::Int64, lp::SpaceParm) = mod(div(nr-1,lp.rbkS[id]),lp.rbk[id]) i2 = cnt(p[1], p[2], 2, lp)
i3 = cnt(p[1], p[2], 3, lp)
@inline cnt(nb, nr, id::Int64, lp::SpaceParm) = 1 + cntb(nb,id,lp) + cntr(nr,id,lp)*lp.blk[id] # pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return CartesianIndex{4}(i1,i2,i3)
return cnt(p[1], p[2], 1, lp) end
@inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm{4,M,D}) where {M,D}
i1 = cnt(p[1], p[2], 1, lp)
i2 = cnt(p[1], p[2], 2, lp)
i3 = cnt(p[1], p[2], 3, lp)
i4 = cnt(p[1], p[2], 4, lp)
# pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return CartesianIndex{4}(i1,i2,i3,i4)
end
@inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm{5,M,D}) where {M,D}
i1 = cnt(p[1], p[2], 1, lp)
i2 = cnt(p[1], p[2], 2, lp)
i3 = cnt(p[1], p[2], 3, lp)
i4 = cnt(p[1], p[2], 4, lp)
i5 = cnt(p[1], p[2], 5, lp)
# pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return CartesianIndex{4}(i1,i2,i3,i4,i5)
end
@inline function point_coord(p::NTuple{2,Int64}, lp::SpaceParm{6,M,D}) where {M,D}
i1 = cnt(p[1], p[2], 1, lp)
i2 = cnt(p[1], p[2], 2, lp)
i3 = cnt(p[1], p[2], 3, lp)
i4 = cnt(p[1], p[2], 4, lp)
i5 = cnt(p[1], p[2], 5, lp)
i6 = cnt(p[1], p[2], 6, lp)
# pt = ntuple(i -> cnt(p[1], p[2], i, lp), lp.ndim)
return CartesianIndex{4}(i1,i2,i3,i4,i5,i6)
end
@inline function point_time(p::NTuple{2,Int64}, lp::SpaceParm{N,M,D}) where {N,M,D}
return cnt(p[1], p[2], N, lp)
end end

View file

@ -66,8 +66,8 @@ struct YMworkspace{T}
mm = vector_field(SU3alg{T}, lp) mm = vector_field(SU3alg{T}, lp)
u1 = vector_field(SU3{T}, lp) u1 = vector_field(SU3{T}, lp)
end end
cs = scalar_field(Complex{T}, lp) cs = scalar_field_point(Complex{T}, lp)
rs = scalar_field(T, lp) rs = scalar_field_point(T, lp)
end end
return new{T}(GRP,ALG,T,f1, f2, mm, u1, cs, rs) return new{T}(GRP,ALG,T,f1, f2, mm, u1, cs, rs)

View file

@ -15,7 +15,7 @@ function krnl_impr!(plx, U::AbstractArray{T}, c0, c1, lp::SpaceParm{N,M,D}) wher
Ush = @cuStaticSharedMem(T, (D,2)) Ush = @cuStaticSharedMem(T, (D,2))
plx[b,r] = zero(plx[b,r]) S = zero(eltype(plx))
for id1 in 1:N-1 for id1 in 1:N-1
bu1, ru1 = up((b, r), id1, lp) bu1, ru1 = up((b, r), id1, lp)
Ush[b,1] = U[b,id1,r] Ush[b,1] = U[b,id1,r]
@ -85,10 +85,13 @@ function krnl_impr!(plx, U::AbstractArray{T}, c0, c1, lp::SpaceParm{N,M,D}) wher
g2 = Ush[b,2]\Ush[b,1] g2 = Ush[b,2]\Ush[b,1]
plx[b,r] += c0*tr(g2*ga/gb) + c1*( tr(g2*h2/gb) + tr(g2*ga/h3)) S += c0*tr(g2*ga/gb) + c1*( tr(g2*h2/gb) + tr(g2*ga/h3))
end end
end end
I = point_coord((b,r), lp)
plx[I] = S
return nothing return nothing
end end
@ -98,7 +101,7 @@ function krnl_plaq!(plx, U::AbstractArray{T}, lp::SpaceParm{N,M,D}) where {T,N,M
Ush = @cuStaticSharedMem(T, (D,2)) Ush = @cuStaticSharedMem(T, (D,2))
plx[b,r] = zero(plx[b,r]) S = zero(eltype(plx))
for id1 in 1:N-1 for id1 in 1:N-1
bu1, ru1 = up((b, r), id1, lp) bu1, ru1 = up((b, r), id1, lp)
Ush[b,1] = U[b,id1,r] Ush[b,1] = U[b,id1,r]
@ -119,10 +122,13 @@ function krnl_plaq!(plx, U::AbstractArray{T}, lp::SpaceParm{N,M,D}) where {T,N,M
gt2 = U[bu2,id1,ru2] gt2 = U[bu2,id1,ru2]
end end
plx[b,r] += tr(Ush[b,1]*gt1 / (Ush[b,2]*gt2)) S += tr(Ush[b,1]*gt1 / (Ush[b,2]*gt2))
end end
end end
I = point_coord((b,r), lp)
plx[I] = S
return nothing return nothing
end end

View file

@ -39,6 +39,7 @@ println("\n## WILSON ACTION/FLOW TIMES")
gp = GaugeParm{PREC}(6.0, 1.0, (0.0,0.0), 3) gp = GaugeParm{PREC}(6.0, 1.0, (0.0,0.0), 3)
println("Gauge Parameters: ", gp) println("Gauge Parameters: ", gp)
println("Initial Action: ") println("Initial Action: ")
@time S = gauge_action(U, lp, gp, ymws) @time S = gauge_action(U, lp, gp, ymws)