CG! scalar product and typos fixed

This commit is contained in:
Fernando P. Panadero 2023-11-20 14:39:52 +01:00
parent 5048fc85fa
commit dea04bccff

View file

@ -14,21 +14,43 @@
Solves the linear equation `Ax = si` Solves the linear equation `Ax = si`
""" """
function CG!(si, U, m0, A, lp::SpaceParm, dws::DiracWorkspace) function krnl_dot!(sum,fone,ftwo)
b=Int64(CUDA.threadIdx().x)
r=Int64(CUDA.blockIdx().x)
sum[b,r] = dot(fone[b,r],ftwo[b,r])
return nothing
end
function field_dot(fone::AbstractArray,ftwo::AbstractArray,sumf,lp) where {T}
CUDA.@sync begin
CUDA.@cuda threads=lp.bsz blocks=lp.rsz krnl_dot!(sumf,fone,ftwo)
end
return sum(sumf)
end
function CG!(si, U, A, dpar::DiracParam, lp::SpaceParm, dws::DiracWorkspace{T}, maxiter::Int64 = 10, tol=1.0) where {T}
dws.sr .= si dws.sr .= si
dws.sp .= si dws.sp .= si
norm = CUDA.mapreduce(x -> norm2(x), +, si) norm = CUDA.mapreduce(x -> norm2(x), +, si)
fill!(si,zero(eltype(so))) fill!(si,zero(eltype(si)))
err = 0.0 err = 0.0
tol = eps * norm tol = tol * norm
iterations = 0 iterations = 0
sumf = scalar_field(Complex{T}, lp)
niter = 0 niter = 0
for i in 1:maxiter for i in 1:maxiter
A(dws.sAp, U, dws.sp, am0, dws.st, lp) A(dws.sAp, U, dws.sp, dpar, dws, lp)
prod = CUDA.mapreduce(x -> dot(x[1],x[2]), +, zip(dws.sp, dws.sAp))
prod = field_dot(dws.sp,dws.sAp,sumf,lp)
alpha = norm/prod alpha = norm/prod
si .= si .+ alpha .* dws.sp si .= si .+ alpha .* dws.sp