From 7a51966eba8409c02dc195fbec2e06b002aac4c0 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 22 Sep 2025 16:04:47 -0400 Subject: [PATCH 01/55] Add `Problem` as type parameter to `SweepIterator` RegionPlan is ommited as this is just vector of kwargs whos type is unimportant --- src/solvers/iterators.jl | 90 +++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 023818ba..e082db56 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -1,44 +1,3 @@ -# -# SweepIterator -# - -mutable struct SweepIterator - sweep_kws - region_iter - which_sweep::Int -end - -problem(S::SweepIterator) = problem(S.region_iter) - -Base.length(S::SweepIterator) = length(S.sweep_kws) - -function Base.iterate(S::SweepIterator, which=nothing) - if isnothing(which) - sweep_kws_state = iterate(S.sweep_kws) - else - sweep_kws_state = iterate(S.sweep_kws, which) - end - isnothing(sweep_kws_state) && return nothing - current_sweep_kws, next = sweep_kws_state - - if !isnothing(which) - S.region_iter = region_iterator( - problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws... - ) - end - S.which_sweep += 1 - return S.region_iter, next -end - -function sweep_iterator(problem, sweep_kws) - region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...) - return SweepIterator(sweep_kws, region_iter, 1) -end - -function sweep_iterator(problem, nsweeps::Integer; sweep_kws...) - return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps)) -end - # # RegionIterator # @@ -54,10 +13,14 @@ current_region_plan(R::RegionIterator) = R.region_plan[R.which_region] current_region(R::RegionIterator) = current_region_plan(R)[1] region_kwargs(R::RegionIterator) = current_region_plan(R)[2] function previous_region(R::RegionIterator) - R.which_region==1 ? nothing : R.region_plan[R.which_region - 1][1] + return R.which_region == 1 ? nothing : R.region_plan[R.which_region - 1][1] end function next_region(R::RegionIterator) - R.which_region==length(R.region_plan) ? nothing : R.region_plan[R.which_region + 1][1] + return if R.which_region == length(R.region_plan) + nothing + else + R.region_plan[R.which_region + 1][1] + end end is_last_region(R::RegionIterator) = isnothing(next_region(R)) @@ -98,3 +61,44 @@ end function region_plan(problem; kws...) return euler_sweep(state(problem); kws...) end + +# +# SweepIterator +# + +mutable struct SweepIterator{Problem} + sweep_kws + region_iter::RegionIterator{Problem} + which_sweep::Int +end + +problem(S::SweepIterator) = problem(S.region_iter) + +Base.length(S::SweepIterator) = length(S.sweep_kws) + +function Base.iterate(S::SweepIterator, which=nothing) + if isnothing(which) + sweep_kws_state = iterate(S.sweep_kws) + else + sweep_kws_state = iterate(S.sweep_kws, which) + end + isnothing(sweep_kws_state) && return nothing + current_sweep_kws, next = sweep_kws_state + + if !isnothing(which) + S.region_iter = region_iterator( + problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws... + ) + end + S.which_sweep += 1 + return S.region_iter, next +end + +function sweep_iterator(problem, sweep_kws) + region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...) + return SweepIterator(sweep_kws, region_iter, 1) +end + +function sweep_iterator(problem, nsweeps::Integer; sweep_kws...) + return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps)) +end From 245182b734cad640fbcad7084a3edee17a6c0d08 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 24 Sep 2025 15:09:36 -0400 Subject: [PATCH 02/55] Format test files and improve comparisons for readabilty on failure --- test/solvers/test_applyexp.jl | 40 +++++++++++++++++------------------ test/solvers/test_eigsolve.jl | 8 +++---- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/test/solvers/test_applyexp.jl b/test/solvers/test_applyexp.jl index 69b83b7c..acbd8318 100644 --- a/test/solvers/test_applyexp.jl +++ b/test/solvers/test_applyexp.jl @@ -12,12 +12,12 @@ function chain_plus_ancilla(; nchain) for j in 1:nchain add_vertex!(g, j) end - for j in 1:(nchain - 1) - add_edge!(g, j=>j+1) + for j in 1:(nchain-1) + add_edge!(g, j => j + 1) end # Add ancilla vertex near middle of chain add_vertex!(g, 0) - add_edge!(g, 0=>nchain÷2) + add_edge!(g, 0 => nchain ÷ 2) return g end @@ -31,10 +31,10 @@ end # Make Heisenberg model Hamiltonian h = OpSum() - for j in 1:(N - 1) - h += "Sz", j, "Sz", j+1 - h += 1/2, "S+", j, "S-", j+1 - h += 1/2, "S-", j, "S+", j+1 + for j in 1:(N-1) + h += "Sz", j, "Sz", j + 1 + h += 1 / 2, "S+", j, "S-", j + 1 + h += 1 / 2, "S-", j, "S+", j + 1 end H = ttn(h, sites) @@ -54,7 +54,7 @@ end E, gs_psi = dmrg(H, psi0; insert_kwargs=(; trunc), nsites, nsweeps, outputlevel) (outputlevel >= 1) && println("2-site DMRG energy = ", E) - insert_kwargs=(; trunc) + insert_kwargs = (; trunc) nsites = 1 tmax = 0.10 time_range = 0.0:0.02:tmax @@ -73,7 +73,7 @@ end # Test that accumulated phase angle is E*tmax z = inner(psi1_t, gs_psi) - @test abs(atan(imag(z)/real(z)) - E*tmax) < 1E-4 + @test atan(imag(z) / real(z)) ≈ E * tmax atol = 1E-4 end @testset "Applyexp Time Point Handling" begin @@ -83,10 +83,10 @@ end # Make Heisenberg model Hamiltonian h = OpSum() - for j in 1:(N - 1) - h += "Sz", j, "Sz", j+1 - h += 1/2, "S+", j, "S-", j+1 - h += 1/2, "S-", j, "S+", j+1 + for j in 1:(N-1) + h += "Sz", j, "Sz", j + 1 + h += 1 / 2, "S+", j, "S-", j + 1 + h += 1 / 2, "S-", j, "S+", j + 1 end H = ttn(h, sites) @@ -99,23 +99,23 @@ end nsites = 2 trunc = (; cutoff=1E-8, maxdim=100) - insert_kwargs=(; trunc) + insert_kwargs = (; trunc) # Test that all time points are reached and reported correctly - time_points = [0.0,0.1,0.25,0.32,0.4] + time_points = [0.0, 0.1, 0.25, 0.32, 0.4] times = Real[] function collect_times(problem; kws...) push!(times, ITensorNetworks.current_time(problem)) end - time_evolve(H, time_points, psi0; insert_kwargs, nsites, sweep_callback=collect_times,outputlevel=1) - @test norm(times - time_points) < 10*eps(Float64) + time_evolve(H, time_points, psi0; insert_kwargs, nsites, sweep_callback=collect_times, outputlevel=1) + @test times ≈ time_points atol = 10 * eps(Float64) # Test that all exponents are reached and reported correctly - exponent_points = [-0.0,-0.1,-0.25,-0.32,-0.4] + exponent_points = [-0.0, -0.1, -0.25, -0.32, -0.4] exponents = Real[] function collect_exponents(problem; kws...) push!(exponents, ITensorNetworks.current_exponent(problem)) end - applyexp(H, exponent_points, psi0; insert_kwargs, nsites, sweep_callback=collect_exponents,outputlevel=1) - @test norm(exponents - exponent_points) < 10*eps(Float64) + applyexp(H, exponent_points, psi0; insert_kwargs, nsites, sweep_callback=collect_exponents, outputlevel=1) + @test exponents ≈ exponent_points atol = 10 * eps(Float64) end diff --git a/test/solvers/test_eigsolve.jl b/test/solvers/test_eigsolve.jl index 75194cae..1d18a6d8 100644 --- a/test/solvers/test_eigsolve.jl +++ b/test/solvers/test_eigsolve.jl @@ -20,8 +20,8 @@ include("utilities/tree_graphs.jl") for edge in edges(sites) i, j = src(edge), dst(edge) h += "Sz", i, "Sz", j - h += 1/2, "S+", i, "S-", j - h += 1/2, "S-", i, "S+", j + h += 1 / 2, "S+", i, "S-", j + h += 1 / 2, "S-", i, "S+", j end H = ttn(h, sites) @@ -48,7 +48,7 @@ include("utilities/tree_graphs.jl") insert_kwargs = (; trunc) E, psi = dmrg(H, psi0; insert_kwargs, nsites, nsweeps, outputlevel) (outputlevel >= 1) && println("2-site DMRG energy = ", E) - @test abs(E-Ex) < 1E-5 + @test E ≈ Ex atol = 1E-5 # # Test 1-site DMRG with subspace expansion @@ -60,5 +60,5 @@ include("utilities/tree_graphs.jl") insert_kwargs = (; trunc) E, psi = dmrg(H, psi0; extract_kwargs, insert_kwargs, nsites, nsweeps, outputlevel) (outputlevel >= 1) && println("1-site+subspace DMRG energy = ", E) - @test abs(E-Ex) < 1E-5 + @test E ≈ Ex atol = 1E-5 end From 7af4b252e82a9f7dd56298e26a51de4ec2170d6a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 24 Sep 2025 15:23:21 -0400 Subject: [PATCH 03/55] Redesign iterator interface by introducing AbstractNetworkIterator abstract type Other changes: - Both `sweep_callback` and `region_callback` in `sweep_solve` now take only one positional argument, the sweep iterator. - Iterating `SweepIterator` now automatically performs the RegionIteration - Added an 'adapter' `PauseAfterIncrement` that allows `SweepIterator` to be iterated without performing region iteration - `RegionIteration` now tracks the current sweep number - Replaced some function calls with explict calls to constructors to make it clear when new iterators are being constructed (instead of returned from a field etc). Note, AbstractNetworkIterator interface requires some documentation. --- src/solvers/adapters.jl | 18 ++++ src/solvers/applyexp.jl | 39 ++++---- src/solvers/eigsolve.jl | 16 ++-- src/solvers/fitting.jl | 2 +- src/solvers/iterators.jl | 174 ++++++++++++++++++++++------------ src/solvers/sweep_solve.jl | 50 ++++------ test/solvers/test_applyexp.jl | 8 +- 7 files changed, 187 insertions(+), 120 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 7c033d8e..44c70b20 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -30,3 +30,21 @@ iterator which outputs a tuple of the form (current_region, current_region_kwarg at each step. """ region_tuples(R::RegionIterator) = TupleRegionIterator(R) + +""" + struct PauseAfterIncrement{S<:AbstractNetworkIterator} + +Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the +process. This allows one to manually call a custom `compute!` or insert their own code it in +the loop body in place of `compute!`. +""" +struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterator + parent::S +end + +done(NC::PauseAfterIncrement) = done(NC.parent) +state(NC::PauseAfterIncrement) = state(NC.parent) +increment!(NC::PauseAfterIncrement) = increment!(NC.parent) +compute!(NC::PauseAfterIncrement) = NC + +PauseAfterIncrement(NC::PauseAfterIncrement) = NC diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 8bc070e3..144a77e6 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -11,7 +11,7 @@ operator(A::ApplyExpProblem) = A.operator state(A::ApplyExpProblem) = A.state current_exponent(A::ApplyExpProblem) = A.current_exponent function current_time(A::ApplyExpProblem) - t = im*A.current_exponent + t = im * A.current_exponent return iszero(imag(t)) ? real(t) : t end @@ -36,9 +36,9 @@ function update( iszero(abs(exponent_step)) && return prob, local_state local_state, info = solver( - x->optimal_map(operator(prob), x), exponent_step, local_state; kws... + x -> optimal_map(operator(prob), x), exponent_step, local_state; kws... ) - if nsites==1 + if nsites == 1 curr_reg = current_region(region_iterator) next_reg = next_region(region_iterator) if !isnothing(next_reg) && next_reg != curr_reg @@ -46,31 +46,32 @@ function update( v1, v2 = src(next_edge), dst(next_edge) psi = copy(state(prob)) psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2])) - shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2)) - R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...) - local_state = psi[v1]*R_t + shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2)) + R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...) + local_state = psi[v1] * R_t end end - prob = set_current_exponent(prob, current_exponent(prob)+exponent_step) + prob = set_current_exponent(prob, current_exponent(prob) + exponent_step) return prob, local_state end -function sweep_callback( - problem::ApplyExpProblem; +function default_sweep_callback( + sweep_iterator::SweepIterator{<:ApplyExpProblem}; exponent_description="exponent", outputlevel, - sweep, - nsweeps, process_time=identity, - kws..., + kwargs..., ) if outputlevel >= 1 + the_problem = problem(sweep_iterator) @printf( - " Current %s = %s, ", exponent_description, process_time(current_exponent(problem)) + " Current %s = %s, ", + exponent_description, + process_time(current_exponent(the_problem)) ) - @printf("maxlinkdim=%d", maxlinkdim(state(problem))) + @printf("maxlinkdim=%d", maxlinkdim(state(the_problem))) println() flush(stdout) end @@ -88,9 +89,10 @@ function applyexp( kws..., ) exponent_steps = diff([zero(eltype(exponents)); exponents]) + # exponent_steps = diff(exponents) sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs) kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps] - sweep_iter = sweep_iterator(init_prob, kws_array) + sweep_iter = SweepIterator(init_prob, kws_array) converged_prob = sweep_solve(sweep_iter; outputlevel, kws...) return state(converged_prob) end @@ -111,11 +113,10 @@ function time_evolve( time_points, init_state; process_time=process_real_times, - sweep_callback=( - a...; k... - )->sweep_callback(a...; exponent_description="time", process_time, k...), + sweep_callback=(a...; k...) -> + default_sweep_callback(a...; exponent_description="time", process_time, k...), kws..., ) - exponents = [-im*t for t in time_points] + exponents = [-im * t for t in time_points] return applyexp(operator, exponents, init_state; sweep_callback, kws...) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 6916406a..0f4784f3 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -34,7 +34,7 @@ function update( solver=eigsolve_solver, kws..., ) - eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...) + eigval, local_state = solver(ψ -> optimal_map(operator(prob), ψ), local_state; kws...) prob = set_eigenvalue(prob, eigval) if outputlevel >= 2 @printf( @@ -44,12 +44,16 @@ function update( return prob, local_state end -function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...) +function default_sweep_callback( + sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel +) if outputlevel >= 1 - if nsweeps >= 10 - @printf("After sweep %02d/%d ", sweep, nsweeps) + nsweeps = length(sweep_iterator) + current_sweep = sweep_iterator.which_sweep + if length(sweep_iterator) >= 10 + @printf("After sweep %02d/%d ", current_sweep, nsweeps) else - @printf("After sweep %d/%d ", sweep, nsweeps) + @printf("After sweep %d/%d ", current_sweep, nsweeps) end @printf("eigenvalue=%.12f", eigenvalue(problem)) @printf(" maxlinkdim=%d", maxlinkdim(state(problem))) @@ -73,7 +77,7 @@ function eigsolve( init_prob = EigsolveProblem(; state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) - sweep_iter = sweep_iterator( + sweep_iter = SweepIterator( init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs ) prob = sweep_solve(sweep_iter; outputlevel, kws...) diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index e04df71d..be7a4e99 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -79,7 +79,7 @@ function fit_tensornetwork( insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false) common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs) kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps] - sweep_iter = sweep_iterator(init_prob, kwargs_array) + sweep_iter = SweepIterator(init_prob, kwargs_array) converged_prob = sweep_solve(sweep_iter; outputlevel, kws...) return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob)) end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index e082db56..9c3d7da8 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -1,61 +1,113 @@ +""" + abstract type AbstractNetworkIterator + +A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins +with a call to `increment!` before executing `compute!`, however the initial call to +`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that +this call is implict. Termination of the iterator is controlled by the function `done`. +""" +abstract type AbstractNetworkIterator end + +# We use greater than or equals here as we increment the state at the start of the iteration +done(NI::AbstractNetworkIterator) = state(NI) >= length(NI) + +function Base.iterate(NI::AbstractNetworkIterator, init=true) + done(NI) && return nothing + # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* + # define a method for increment! This way we avoid cases where one may wish to nest + # calls to different step! methods accidentaly incrementing multiple times. + init || increment!(NI) + rv = compute!(NI) + return rv, false +end + +function increment! end +compute!(NI::AbstractNetworkIterator) = NI + +step!(NI::AbstractNetworkIterator) = step!(identity, NI) +function step!(f, NI::AbstractNetworkIterator) + compute!(NI) + f(NI) + increment!(NI) + return NI +end + # # RegionIterator # - -@kwdef mutable struct RegionIterator{Problem,RegionPlan} +""" + struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator +""" +mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator problem::Problem region_plan::RegionPlan - which_region::Int = 1 + const sweep::Int + which_region::Int + function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R} + return new{P,R}(problem, region_plan, sweep, 1) + end end +state(R::RegionIterator) = R.which_region +Base.length(R::RegionIterator) = length(R.region_plan) + problem(R::RegionIterator) = R.problem + current_region_plan(R::RegionIterator) = R.region_plan[R.which_region] -current_region(R::RegionIterator) = current_region_plan(R)[1] -region_kwargs(R::RegionIterator) = current_region_plan(R)[2] -function previous_region(R::RegionIterator) - return R.which_region == 1 ? nothing : R.region_plan[R.which_region - 1][1] + +function current_region(R::RegionIterator) + region, _ = current_region_plan(R) + return region end -function next_region(R::RegionIterator) - return if R.which_region == length(R.region_plan) - nothing - else - R.region_plan[R.which_region + 1][1] - end + +function current_region_kwargs(R::RegionIterator) + _, kwargs = current_region_plan(R) + return kwargs end -is_last_region(R::RegionIterator) = isnothing(next_region(R)) -function Base.iterate(R::RegionIterator, which=1) - R.which_region = which - region_plan_state = iterate(R.region_plan, which) - isnothing(region_plan_state) && return nothing - (current_region, region_kwargs), next = region_plan_state - R.problem = region_step(problem(R), R; region_kwargs...) - return R, next +function previous_region(R::RegionIterator) + state(R) <= 1 && return nothing + prev, _ = R.region_plan[R.which_region - 1] + return prev +end + +function next_region(R::RegionIterator) + is_last_region(R) && return nothing + next, _ = R.region_plan[R.which_region + 1] + return next end +is_last_region(R::RegionIterator) = length(R) === state(R) # # Functions associated with RegionIterator # -function region_iterator(problem; sweep_kwargs...) - return RegionIterator(; problem, region_plan=region_plan(problem; sweep_kwargs...)) +function compute!(R::RegionIterator) + region_kwargs = current_region_kwargs(R) + R.problem = region_step(R; region_kwargs...) + return R +end +function increment!(R::RegionIterator) + R.which_region += 1 + return R +end + +function RegionIterator(problem; sweep, sweep_kwargs...) + plan = region_plan(problem; sweep, sweep_kwargs...) + return RegionIterator(problem, plan, sweep) end function region_step( - problem, - region_iterator; - extract_kwargs=(;), - update_kwargs=(;), - insert_kwargs=(;), - sweep, - kws..., + region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws... ) - problem, local_state = extract(problem, region_iterator; extract_kwargs..., sweep, kws...) - problem, local_state = update( - problem, local_state, region_iterator; update_kwargs..., kws... - ) - problem = insert(problem, local_state, region_iterator; sweep, insert_kwargs..., kws...) - return problem + prob = problem(region_iterator) + + sweep = region_iterator.sweep + + prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...) + prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...) + prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...) + return prob end function region_plan(problem; kws...) @@ -66,39 +118,41 @@ end # SweepIterator # -mutable struct SweepIterator{Problem} +mutable struct SweepIterator{Problem} <: AbstractNetworkIterator sweep_kws region_iter::RegionIterator{Problem} which_sweep::Int + function SweepIterator(problem, sweep_kws) + sweep_kws = Iterators.Stateful(sweep_kws) + first_kwargs, _ = Iterators.peel(sweep_kws) + region_iter = RegionIterator(problem; sweep=1, first_kwargs...) + return new{typeof(problem)}(sweep_kws, region_iter, 1) + end end -problem(S::SweepIterator) = problem(S.region_iter) - -Base.length(S::SweepIterator) = length(S.sweep_kws) +done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws)) -function Base.iterate(S::SweepIterator, which=nothing) - if isnothing(which) - sweep_kws_state = iterate(S.sweep_kws) - else - sweep_kws_state = iterate(S.sweep_kws, which) - end - isnothing(sweep_kws_state) && return nothing - current_sweep_kws, next = sweep_kws_state +region_iterator(S::SweepIterator) = S.region_iter +problem(S::SweepIterator) = problem(region_iterator(S)) - if !isnothing(which) - S.region_iter = region_iterator( - problem(S.region_iter); sweep=S.which_sweep, current_sweep_kws... - ) - end - S.which_sweep += 1 - return S.region_iter, next +state(SR::SweepIterator) = SR.which_sweep +Base.length(S::SweepIterator) = length(S.sweep_kws) +function increment!(SR::SweepIterator) + SR.which_sweep += 1 + sweep_kwargs, _ = Iterators.peel(SR.sweep_kws) + SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...) + return SR end -function sweep_iterator(problem, sweep_kws) - region_iter = region_iterator(problem; sweep=1, first(sweep_kws)...) - return SweepIterator(sweep_kws, region_iter, 1) +function compute!(SR::SweepIterator) + for _ in SR.region_iter + # TODO: Is it sensible to execute the default region callback function? + end end -function sweep_iterator(problem, nsweeps::Integer; sweep_kws...) - return sweep_iterator(problem, Iterators.repeated(sweep_kws, nsweeps)) +# More basic constructor where sweep_kwargs are constant throughout sweeps +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) + # Initialize this to an empty RegionIterator + sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) + return SweepIterator(problem, sweep_kwargs_iter) end diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index 3da97728..b1f082da 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -1,40 +1,30 @@ -region_callback(problem; kws...) = nothing - -function sweep_callback(problem; outputlevel, sweep, nsweeps, kws...) - if outputlevel >= 1 - println("Done with sweep $sweep/$nsweeps") - end +function default_region_callback(sweep_iterator; kwargs...) + return sweep_iterator end - +function default_sweep_callback(sweep_iterator; kwargs...) + return sweep_iterator +end +# In this implementation the function `sweep_solve` is essentially just a wrapper around +# the iterate interface that allows one to pass callbacks. function sweep_solve( sweep_iterator; + sweep_callback=default_sweep_callback, + region_callback=default_region_callback, outputlevel=0, - region_callback=region_callback, - sweep_callback=sweep_callback, - kwargs..., ) - for (sweep, region_iter) in enumerate(sweep_iterator) - for (region, region_kwargs) in region_tuples(region_iter) - region_callback( - problem(region_iter); - nsweeps=length(sweep_iterator), - outputlevel, - region_iterator=region_iter, - region, - region_kwargs, - sweep, - kwargs..., - ) + # Don't compute the region iteration automatically as we wish to insert a callback. + for _ in PauseAfterIncrement(sweep_iterator) + for _ in region_iterator(sweep_iterator) + region_callback(sweep_iterator; outputlevel=outputlevel) end - sweep_callback( - problem(region_iter); - nsweeps=length(sweep_iterator), - outputlevel, - region_iterator=region_iter, - sweep, - kwargs..., - ) + sweep_callback(sweep_iterator; outputlevel=outputlevel) end return problem(sweep_iterator) end + +# I suspect that `sweep_callback` is the more commonly used callback, so allow this to +# be set using the `do` syntax. +function sweep_solve(sweep_callback, sweep_iterator; kwargs...) + return sweep_solve(sweep_iterator; sweep_callback=sweep_callback, kwargs...) +end diff --git a/test/solvers/test_applyexp.jl b/test/solvers/test_applyexp.jl index acbd8318..c45a4817 100644 --- a/test/solvers/test_applyexp.jl +++ b/test/solvers/test_applyexp.jl @@ -104,8 +104,8 @@ end # Test that all time points are reached and reported correctly time_points = [0.0, 0.1, 0.25, 0.32, 0.4] times = Real[] - function collect_times(problem; kws...) - push!(times, ITensorNetworks.current_time(problem)) + function collect_times(sweep_iterator; kws...) + push!(times, ITensorNetworks.current_time(ITensorNetworks.problem(sweep_iterator))) end time_evolve(H, time_points, psi0; insert_kwargs, nsites, sweep_callback=collect_times, outputlevel=1) @test times ≈ time_points atol = 10 * eps(Float64) @@ -113,8 +113,8 @@ end # Test that all exponents are reached and reported correctly exponent_points = [-0.0, -0.1, -0.25, -0.32, -0.4] exponents = Real[] - function collect_exponents(problem; kws...) - push!(exponents, ITensorNetworks.current_exponent(problem)) + function collect_exponents(sweep_iterator; kws...) + push!(exponents, ITensorNetworks.current_exponent(ITensorNetworks.problem(sweep_iterator))) end applyexp(H, exponent_points, psi0; insert_kwargs, nsites, sweep_callback=collect_exponents, outputlevel=1) @test exponents ≈ exponent_points atol = 10 * eps(Float64) From c0ae5d090079b54897f26fbfa8cf44ed683ddc5e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 25 Sep 2025 15:07:06 -0400 Subject: [PATCH 04/55] Add `EachRegion` adapter that wraps `RegionIterator`, behaving the same but returning a tuple (region, kwargs) at each step --- src/solvers/adapters.jl | 55 +++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 44c70b20..92d70192 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -1,36 +1,3 @@ - -# -# TupleRegionIterator -# -# Adapts outputs to be (region, region_kwargs) tuples -# -# More generic design? maybe just assuming RegionIterator -# or its outputs implement some interface function that -# generates each tuple? -# - -mutable struct TupleRegionIterator{RegionIter} - region_iterator::RegionIter -end - -region_iterator(T::TupleRegionIterator) = T.region_iterator - -function Base.iterate(T::TupleRegionIterator, which=1) - state = iterate(region_iterator(T), which) - isnothing(state) && return nothing - (current_region, region_kwargs) = current_region_plan(region_iterator(T)) - return (current_region, region_kwargs), last(state) -end - -""" - region_tuples(R::RegionIterator) - -The `region_tuples` adapter converts a RegionIterator into an -iterator which outputs a tuple of the form (current_region, current_region_kwargs) -at each step. -""" -region_tuples(R::RegionIterator) = TupleRegionIterator(R) - """ struct PauseAfterIncrement{S<:AbstractNetworkIterator} @@ -48,3 +15,25 @@ increment!(NC::PauseAfterIncrement) = increment!(NC.parent) compute!(NC::PauseAfterIncrement) = NC PauseAfterIncrement(NC::PauseAfterIncrement) = NC + +""" + struct EachRegion{RegionIterator} <: AbstractNetworkIterator + +Wapper adapter that returns a tuple (region, kwargs) at each step rather than the iterator +itself. +""" +struct EachRegion{R<:RegionIterator} <: AbstractNetworkIterator + parent::R +end + +# Essential definitions +Base.length(ER::EachRegion) = length(ER.parent) +state(ER::EachRegion) = state(ER.parent) +increment!(ER::EachRegion) = state(ER.parent) + +function compute!(ER::EachRegion) + # Do the usual compute! for RegionIterator + compute!(ER.parent) + # But now lets return something useful + return current_region_plan(ER) +end From 3b9d0af35e3b5aa65d8e1b5cffc0cf3c1bd56d3c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 30 Sep 2025 10:20:09 -0400 Subject: [PATCH 05/55] Add unit tests for the `AbstractNetworkIterator` interface --- test/solvers/test_iterators.jl | 126 +++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 test/solvers/test_iterators.jl diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl new file mode 100644 index 00000000..2c7e0b4d --- /dev/null +++ b/test/solvers/test_iterators.jl @@ -0,0 +1,126 @@ +using Test: @test, @testset +using ITensorNetworks: done, state, increment!, compute! + +module TestIteratorUtils + +using ITensorNetworks + +mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator + state::Int + max::Int + output::Vector{Int} +end + +ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 +Base.length(TI::TestIterator) = TI.max +ITensorNetworks.state(TI::TestIterator) = TI.state +function ITensorNetworks.compute!(TI::TestIterator) + push!(TI.output, ITensorNetworks.state(TI)) + return TI +end + +mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator + parent::TestIterator +end + +Base.length(SA::SquareAdapter) = length(SA.parent) +ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) +ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) +function ITensorNetworks.compute!(SA::SquareAdapter) + ITensorNetworks.compute!(SA.parent) + return last(SA.parent.output)^2 +end + +end + +@testset "Iterators" begin + + using .TestIteratorUtils: TestIterator, SquareAdapter + + @testset "`AbstractNetworkIterator` Interface" begin + TI = TestIterator(1, 4, []) + + @test !done(TI) + + # First iterator should compute only + rv, st = iterate(TI) + @test !done(TI) + @test !st + @test rv === TI + @test length(TI.output) == 1 + @test only(TI.output) == 1 + @test state(TI) == 1 + @test !st + + rv, st = iterate(TI, st) + @test !done(TI) + @test !st + @test length(TI.output) == 2 + @test state(TI) == 2 + @test TI.output == [1, 2] + + increment!(TI) + @test !done(TI) + @test state(TI) == 3 + @test length(TI.output) == 2 + @test TI.output == [1, 2] + + compute!(TI) + @test !done(TI) + @test state(TI) == 3 + @test length(TI.output) == 3 + @test TI.output == [1, 2, 3] + + # Final Step + iterate(TI, false) + @test done(TI) + @test state(TI) == 4 + @test length(TI.output) == 4 + @test TI.output == [1, 2, 3, 4] + + @test iterate(TI, false) === nothing + + TI = TestIterator(1, 5, []) + + cb = [] + + for _ in TI + @test length(cb) == length(TI.output) - 1 + @test cb == (TI.output)[1:end-1] + push!(cb, state(TI)) + @test cb == TI.output + end + + @test done(TI) + @test length(TI.output) == 5 + @test length(cb) == 5 + @test cb == TI.output + + + TI = TestIterator(1, 5, []) + end + + @testset "Adapters" begin + TI = TestIterator(1, 5, []) + SA = SquareAdapter(TI) + + i = 0 + for rv in SA + i += 1 + @test rv isa Int + @test rv == i^2 + @test state(SA) == i + end + + @test done(SA) + + TI = TestIterator(1, 5, []) + SA = SquareAdapter(TI) + + SA_c = collect(SA) + + @test SA_c isa Vector + @test length(SA_c) == 5 + @test SA_c == [1, 4, 9, 16, 25] + end +end From 4ef4e75aea99e342bfa3d60fa8e8778797f3b1c2 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 30 Sep 2025 10:25:48 -0400 Subject: [PATCH 06/55] Rename `done` to `laststep` to better reflect the when it evalutes to true during the iteration --- src/solvers/adapters.jl | 2 +- src/solvers/iterators.jl | 6 +++--- test/solvers/test_iterators.jl | 18 +++++++++--------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 92d70192..f154f5d0 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -9,7 +9,7 @@ struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterato parent::S end -done(NC::PauseAfterIncrement) = done(NC.parent) +laststep(NC::PauseAfterIncrement) = laststep(NC.parent) state(NC::PauseAfterIncrement) = state(NC.parent) increment!(NC::PauseAfterIncrement) = increment!(NC.parent) compute!(NC::PauseAfterIncrement) = NC diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 9c3d7da8..a9a0c79c 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -9,10 +9,10 @@ this call is implict. Termination of the iterator is controlled by the function abstract type AbstractNetworkIterator end # We use greater than or equals here as we increment the state at the start of the iteration -done(NI::AbstractNetworkIterator) = state(NI) >= length(NI) +laststep(NI::AbstractNetworkIterator) = state(NI) >= length(NI) function Base.iterate(NI::AbstractNetworkIterator, init=true) - done(NI) && return nothing + laststep(NI) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. @@ -130,7 +130,7 @@ mutable struct SweepIterator{Problem} <: AbstractNetworkIterator end end -done(SR::SweepIterator) = isnothing(peek(SR.sweep_kws)) +laststep(SR::SweepIterator) = isnothing(peek(SR.sweep_kws)) region_iterator(S::SweepIterator) = S.region_iter problem(S::SweepIterator) = problem(region_iterator(S)) diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl index 2c7e0b4d..a39722c1 100644 --- a/test/solvers/test_iterators.jl +++ b/test/solvers/test_iterators.jl @@ -1,5 +1,5 @@ using Test: @test, @testset -using ITensorNetworks: done, state, increment!, compute! +using ITensorNetworks: laststep, state, increment!, compute! module TestIteratorUtils @@ -40,11 +40,11 @@ end @testset "`AbstractNetworkIterator` Interface" begin TI = TestIterator(1, 4, []) - @test !done(TI) + @test !laststep((TI)) # First iterator should compute only rv, st = iterate(TI) - @test !done(TI) + @test !laststep((TI)) @test !st @test rv === TI @test length(TI.output) == 1 @@ -53,27 +53,27 @@ end @test !st rv, st = iterate(TI, st) - @test !done(TI) + @test !laststep((TI)) @test !st @test length(TI.output) == 2 @test state(TI) == 2 @test TI.output == [1, 2] increment!(TI) - @test !done(TI) + @test !laststep((TI)) @test state(TI) == 3 @test length(TI.output) == 2 @test TI.output == [1, 2] compute!(TI) - @test !done(TI) + @test !laststep((TI)) @test state(TI) == 3 @test length(TI.output) == 3 @test TI.output == [1, 2, 3] # Final Step iterate(TI, false) - @test done(TI) + @test laststep((TI)) @test state(TI) == 4 @test length(TI.output) == 4 @test TI.output == [1, 2, 3, 4] @@ -91,7 +91,7 @@ end @test cb == TI.output end - @test done(TI) + @test laststep((TI)) @test length(TI.output) == 5 @test length(cb) == 5 @test cb == TI.output @@ -112,7 +112,7 @@ end @test state(SA) == i end - @test done(SA) + @test laststep((SA)) TI = TestIterator(1, 5, []) SA = SquareAdapter(TI) From e112eb4fb20a3ee622cb74f1ade32f8b794bcc4d Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 30 Sep 2025 10:28:15 -0400 Subject: [PATCH 07/55] Rename `previous_region` to `prev_region` to better align with julia `prev`/`next` naming convention --- src/solvers/iterators.jl | 2 +- src/solvers/subspace/ortho_subspace.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index a9a0c79c..05813a56 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -65,7 +65,7 @@ function current_region_kwargs(R::RegionIterator) return kwargs end -function previous_region(R::RegionIterator) +function prev_region(R::RegionIterator) state(R) <= 1 && return nothing prev, _ = R.region_plan[R.which_region - 1] return prev diff --git a/src/solvers/subspace/ortho_subspace.jl b/src/solvers/subspace/ortho_subspace.jl index 7d7ca6c2..26465309 100644 --- a/src/solvers/subspace/ortho_subspace.jl +++ b/src/solvers/subspace/ortho_subspace.jl @@ -28,7 +28,7 @@ function subspace_expand!( max_expand=default_max_expand(), kws..., ) - prev_region = previous_region(region_iterator) + prev_region = prev_region(region_iterator) region = current_region(region_iterator) if isnothing(prev_region) || isa(region, AbstractEdge) return local_tensor From da360e078be24eaef68ea3a3339a2241e29b8bd6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 1 Oct 2025 09:39:03 -0400 Subject: [PATCH 08/55] Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some variable names. --- src/solvers/adapters.jl | 24 +++++----- src/solvers/iterators.jl | 96 ++++++++++++++++++++------------------ src/solvers/sweep_solve.jl | 4 +- 3 files changed, 64 insertions(+), 60 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index f154f5d0..c1d139d0 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -5,16 +5,16 @@ Iterator wrapper whos `compute!` function simply returns itself, doing nothing i process. This allows one to manually call a custom `compute!` or insert their own code it in the loop body in place of `compute!`. """ -struct PauseAfterIncrement{S<:AbstractNetworkIterator} <: AbstractNetworkIterator +struct NoComputeStep{S<:AbstractNetworkIterator} <: AbstractNetworkIterator parent::S end -laststep(NC::PauseAfterIncrement) = laststep(NC.parent) -state(NC::PauseAfterIncrement) = state(NC.parent) -increment!(NC::PauseAfterIncrement) = increment!(NC.parent) -compute!(NC::PauseAfterIncrement) = NC +laststep(adapter::NoComputeStep) = laststep(adapter.parent) +state(adapter::NoComputeStep) = state(adapter.parent) +increment!(adapter::NoComputeStep) = increment!(adapter.parent) +compute!(adapter::NoComputeStep) = adapter -PauseAfterIncrement(NC::PauseAfterIncrement) = NC +NoComputeStep(adapter::NoComputeStep) = adapter """ struct EachRegion{RegionIterator} <: AbstractNetworkIterator @@ -27,13 +27,13 @@ struct EachRegion{R<:RegionIterator} <: AbstractNetworkIterator end # Essential definitions -Base.length(ER::EachRegion) = length(ER.parent) -state(ER::EachRegion) = state(ER.parent) -increment!(ER::EachRegion) = state(ER.parent) +Base.length(adapter::EachRegion) = length(adapter.parent) +state(adapter::EachRegion) = state(adapter.parent) +increment!(adapter::EachRegion) = state(adapter.parent) -function compute!(ER::EachRegion) +function compute!(adapter::EachRegion) # Do the usual compute! for RegionIterator - compute!(ER.parent) + compute!(adapter.parent) # But now lets return something useful - return current_region_plan(ER) + return current_region_plan(adapter) end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 05813a56..71dc1162 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -9,27 +9,27 @@ this call is implict. Termination of the iterator is controlled by the function abstract type AbstractNetworkIterator end # We use greater than or equals here as we increment the state at the start of the iteration -laststep(NI::AbstractNetworkIterator) = state(NI) >= length(NI) +laststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) -function Base.iterate(NI::AbstractNetworkIterator, init=true) - laststep(NI) && return nothing +function Base.iterate(iterator::AbstractNetworkIterator, init=true) + laststep(iterator) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. - init || increment!(NI) - rv = compute!(NI) + init || increment!(iterator) + rv = compute!(iterator) return rv, false end function increment! end -compute!(NI::AbstractNetworkIterator) = NI +compute!(iterator::AbstractNetworkIterator) = iterator -step!(NI::AbstractNetworkIterator) = step!(identity, NI) -function step!(f, NI::AbstractNetworkIterator) - compute!(NI) - f(NI) - increment!(NI) - return NI +step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) +function step!(f, iterator::AbstractNetworkIterator) + compute!(iterator) + f(iterator) + increment!(iterator) + return iterator end # @@ -48,48 +48,50 @@ mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator end end -state(R::RegionIterator) = R.which_region -Base.length(R::RegionIterator) = length(R.region_plan) +state(region_iter::RegionIterator) = region_iter.which_region +Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) -problem(R::RegionIterator) = R.problem +problem(region_iter::RegionIterator) = region_iter.problem -current_region_plan(R::RegionIterator) = R.region_plan[R.which_region] +function current_region_plan(region_iter::RegionIterator) + return region_iter.region_plan[region_iter.which_region] +end -function current_region(R::RegionIterator) - region, _ = current_region_plan(R) +function current_region(region_iter::RegionIterator) + region, _ = current_region_plan(region_iter) return region end -function current_region_kwargs(R::RegionIterator) - _, kwargs = current_region_plan(R) +function current_region_kwargs(region_iter::RegionIterator) + _, kwargs = current_region_plan(region_iter) return kwargs end -function prev_region(R::RegionIterator) - state(R) <= 1 && return nothing - prev, _ = R.region_plan[R.which_region - 1] +function prev_region(region_iter::RegionIterator) + state(region_iter) <= 1 && return nothing + prev, _ = region_iter.region_plan[region_iter.which_region - 1] return prev end -function next_region(R::RegionIterator) - is_last_region(R) && return nothing - next, _ = R.region_plan[R.which_region + 1] +function next_region(region_iter::RegionIterator) + is_last_region(region_iter) && return nothing + next, _ = region_iter.region_plan[region_iter.which_region + 1] return next end -is_last_region(R::RegionIterator) = length(R) === state(R) +is_last_region(region_iter::RegionIterator) = length(region_iter) === state(region_iter) # # Functions associated with RegionIterator # -function compute!(R::RegionIterator) - region_kwargs = current_region_kwargs(R) - R.problem = region_step(R; region_kwargs...) - return R +function compute!(region_iter::RegionIterator) + region_kwargs = current_region_kwargs(region_iter) + region_iter.problem = region_step(region_iter; region_kwargs...) + return region_iter end -function increment!(R::RegionIterator) - R.which_region += 1 - return R +function increment!(region_iter::RegionIterator) + region_iter.which_region += 1 + return region_iter end function RegionIterator(problem; sweep, sweep_kwargs...) @@ -130,22 +132,24 @@ mutable struct SweepIterator{Problem} <: AbstractNetworkIterator end end -laststep(SR::SweepIterator) = isnothing(peek(SR.sweep_kws)) +laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kws)) -region_iterator(S::SweepIterator) = S.region_iter -problem(S::SweepIterator) = problem(region_iterator(S)) +region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter +problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) -state(SR::SweepIterator) = SR.which_sweep -Base.length(S::SweepIterator) = length(S.sweep_kws) -function increment!(SR::SweepIterator) - SR.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(SR.sweep_kws) - SR.region_iter = RegionIterator(problem(SR); sweep=state(SR), sweep_kwargs...) - return SR +state(sweep_iter::SweepIterator) = sweep_iter.which_sweep +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kws) +function increment!(sweep_iter::SweepIterator) + sweep_iter.which_sweep += 1 + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kws) + sweep_iter.region_iter = RegionIterator( + problem(sweep_iter); sweep=state(sweep_iter), sweep_kwargs... + ) + return sweep_iter end -function compute!(SR::SweepIterator) - for _ in SR.region_iter +function compute!(sweep_iter::SweepIterator) + for _ in sweep_iter.region_iter # TODO: Is it sensible to execute the default region callback function? end end diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index b1f082da..9273bad9 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -14,7 +14,7 @@ function sweep_solve( outputlevel=0, ) # Don't compute the region iteration automatically as we wish to insert a callback. - for _ in PauseAfterIncrement(sweep_iterator) + for _ in NoComputeStep(sweep_iterator) for _ in region_iterator(sweep_iterator) region_callback(sweep_iterator; outputlevel=outputlevel) end @@ -26,5 +26,5 @@ end # I suspect that `sweep_callback` is the more commonly used callback, so allow this to # be set using the `do` syntax. function sweep_solve(sweep_callback, sweep_iterator; kwargs...) - return sweep_solve(sweep_iterator; sweep_callback=sweep_callback, kwargs...) + return sweep_solve(sweep_iterator; sweep_callback, kwargs...) end From 8bfc4836128fa73aa3951beeb19bf66943dd7512 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 3 Oct 2025 14:34:10 -0400 Subject: [PATCH 09/55] Make `extract` and `subspace_expand` mutating --- src/solvers/extract.jl | 22 ++++++++------ src/solvers/fitting.jl | 33 +++++++++++++++++---- src/solvers/iterators.jl | 2 +- src/solvers/subspace/densitymatrix.jl | 41 ++++++++++++++------------- src/solvers/subspace/subspace.jl | 21 ++++++-------- 5 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index 011058af..66c5d19c 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -1,12 +1,18 @@ -function extract(problem, region_iterator; sweep, trunc=(;), kws...) +function extract!(region_iterator; sweep, trunc=(;), kws...) + prob = problem(region_iterator) + trunc = truncation_parameters(sweep; trunc...) region = current_region(region_iterator) - psi = orthogonalize(state(problem), region) + psi = orthogonalize(state(prob), region) local_state = prod(psi[v] for v in region) - problem = set_state(problem, psi) - problem, local_state = subspace_expand( - problem, local_state, region_iterator; sweep, trunc, kws... - ) - shifted_operator = position(operator(problem), state(problem), region) - return set_operator(problem, shifted_operator), local_state + + prob.state = psi + + local_state = subspace_expand!(local_state, region_iterator; sweep, trunc, kws...) + + shifted_operator = position(operator(prob), state(prob), region) + + prob.operator = shifted_operator + + return local_state end diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index be7a4e99..c5d4234d 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -26,11 +26,13 @@ function ket(F::FittingProblem) return first(induced_subgraph(tensornetwork(state(F)), ket_vertices)) end -function extract(problem::FittingProblem, region_iterator; sweep, kws...) - region = current_region(region_iterator) - prev_region = gauge_region(problem) - tn = state(problem) - path = edge_sequence_between_regions(ket_graph(problem), prev_region, region) +function extract!(region_iter::RegionIterator{<:FittingProblem}; sweep, kws...) + prob = problem(region_iter) + + region = current_region(region_iter) + prev_region = gauge_region(prob) + tn = state(prob) + path = edge_sequence_between_regions(ket_graph(prob), prev_region, region) tn = gauge_walk(Algorithm("orthogonalize"), tn, path) pe_path = partitionedges(partitioned_tensornetwork(tn), path) tn = update( @@ -40,9 +42,28 @@ function extract(problem::FittingProblem, region_iterator; sweep, kws...) sequence = contraction_sequence(local_tensor; alg="optimal") local_tensor = dag(contract(local_tensor; sequence)) #problem, local_tensor = subspace_expand(problem, local_tensor, region; sweep, kws...) - return setproperties(problem; state=tn, gauge_region=region), local_tensor + + prob.state = tn + prob.gauge_region = region + + return local_tensor end +# function update( +# region_iter::RegionIterator{FittingProblem}, local_tensor, region; outputlevel, kws... +# ) +# F = problem(region_iter) +# +# region = current_region(F) +# +# n = (local_tensor * dag(local_tensor))[] +# F.overlap = n / sqrt(n) +# if outputlevel >= 2 +# @printf(" Region %s: squared overlap = %.12f\n", region, overlap(F)) +# end +# return F, local_tensor +# end + function update(F::FittingProblem, local_tensor, region; outputlevel, kws...) n = (local_tensor * dag(local_tensor))[] F = set_overlap(F, n / sqrt(n)) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 71dc1162..9e495c04 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -106,7 +106,7 @@ function region_step( sweep = region_iterator.sweep - prob, local_state = extract(prob, region_iterator; extract_kwargs..., sweep, kws...) + local_state = extract!(region_iterator; extract_kwargs..., sweep, kws...) prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...) prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...) return prob diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 0c4cfa69..5fc77a16 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -1,9 +1,8 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf -function subspace_expand( +function subspace_expand!( ::Backend"densitymatrix", - problem, local_state::ITensor, region_iterator; expansion_factor, @@ -12,59 +11,63 @@ function subspace_expand( trunc, kws..., ) + prob = problem(region_iterator) + region = current_region(region_iterator) - psi = copy(state(problem)) + psi = copy(state(prob)) - prev_vertex_set = setdiff(pos(operator(problem)), region) - (length(prev_vertex_set) != 1) && return problem, local_state + prev_vertex_set = setdiff(pos(operator(prob)), region) + (length(prev_vertex_set) != 1) && return local_state prev_vertex = only(prev_vertex_set) A = psi[prev_vertex] next_vertices = filter(v -> (hascommoninds(psi[v], A)), region) - isempty(next_vertices) && return problem, local_state + isempty(next_vertices) && return local_state next_vertex = only(next_vertices) C = psi[next_vertex] a = commonind(A, C) - isnothing(a) && return problem, local_state + isnothing(a) && return local_state basis_size = prod(dim.(uniqueinds(A, C))) expanded_maxdim = compute_expansion( dim(a), basis_size; expansion_factor, max_expand, trunc.maxdim ) - expanded_maxdim <= 0 && return problem, local_state + expanded_maxdim <= 0 && return local_state trunc = (; trunc..., maxdim=expanded_maxdim) - envs = environments(operator(problem)) - H = operator(operator(problem)) + envs = environments(operator(prob)) + H = operator(operator(prob)) sqrt_rho = A - for e in incident_edges(operator(problem)) + for e in incident_edges(operator(prob)) (src(e) ∈ region || dst(e) ∈ region) && continue sqrt_rho *= envs[e] end sqrt_rho *= H[prev_vertex] - conj_proj_A(T) = (T - prime(A)*(dag(prime(A))*T)) + conj_proj_A(T) = (T - prime(A) * (dag(prime(A)) * T)) for pass in 1:north_pass sqrt_rho = conj_proj_A(sqrt_rho) end rho = sqrt_rho * dag(noprime(sqrt_rho)) D, U = eigen(rho; trunc..., ishermitian=true) - Uproj(T) = (T - prime(A, a)*(dag(prime(A, a))*T)) + Uproj(T) = (T - prime(A, a) * (dag(prime(A, a)) * T)) for pass in 1:north_pass U = Uproj(U) end - if norm(dag(U)*A) > 1E-10 - @printf("Warning: |U*A| = %.3E in subspace expansion\n", norm(dag(U)*A)) - return problem, local_state + if norm(dag(U) * A) > 1E-10 + @printf("Warning: |U*A| = %.3E in subspace expansion\n", norm(dag(U) * A)) + return local_state end - Ax, ax = directsum(A=>a, U=>commonind(U, D)) + Ax, ax = directsum(A => a, U => commonind(U, D)) expander = dag(Ax) * A psi[prev_vertex] = Ax psi[next_vertex] = expander * C - local_state = expander*local_state + local_state = expander * local_state + + prob.state = psi - return set_state(problem, psi), local_state + return local_state end diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index 1c5dec87..f8549af1 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -4,8 +4,7 @@ using NDTensors.BackendSelection: Backend, @Backend_str default_expansion_factor() = 1.5 default_max_expand() = typemax(Int) -function subspace_expand( - problem, +function subspace_expand!( local_state, region_iterator; expansion_factor=default_expansion_factor(), @@ -17,9 +16,8 @@ function subspace_expand( ) expansion_factor = get_or_last(expansion_factor, sweep) max_expand = get_or_last(max_expand, sweep) - return subspace_expand( + local_state = subspace_expand!( Backend(subspace_algorithm), - problem, local_state, region_iterator; expansion_factor, @@ -27,18 +25,17 @@ function subspace_expand( trunc, kws..., ) + return local_state end -function subspace_expand(backend, problem, local_state, region_iterator; kws...) - error( +function subspace_expand!(backend, local_state, region_iterator; kws...) + return error( "Subspace expansion (subspace_expand!) not defined for requested combination of subspace_algorithm and problem types", ) end -function subspace_expand( - backend::Backend{:nothing}, problem, local_state, region_iterator; kws... -) - problem, local_state +function subspace_expand!(backend::Backend{:nothing}, local_state, region_iterator; kws...) + return local_state end function compute_expansion( @@ -55,9 +52,9 @@ function compute_expansion( expand_maxdim = min(max_expand, expand_maxdim) # Restrict expand_maxdim below theoretical upper limit - expand_maxdim = min(basis_size-current_dim, expand_maxdim) + expand_maxdim = min(basis_size - current_dim, expand_maxdim) # Enforce total maxdim setting (e.g. used in insert step) - expand_maxdim = min(maxdim-current_dim, expand_maxdim) + expand_maxdim = min(maxdim - current_dim, expand_maxdim) # Ensure expand_maxdim is non-negative expand_maxdim = max(0, expand_maxdim) return expand_maxdim From 1ef84981211d804ce926e2bfe2ca20aa17e62397 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 3 Oct 2025 14:43:52 -0400 Subject: [PATCH 10/55] Make `update` mutable --- src/solvers/applyexp.jl | 13 +++++++------ src/solvers/eigsolve.jl | 12 +++++++----- src/solvers/fitting.jl | 29 +++++++++++------------------ src/solvers/iterators.jl | 2 +- 4 files changed, 26 insertions(+), 30 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 144a77e6..c0f60bbe 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -23,17 +23,18 @@ function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...) return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...) end -function update( - prob::ApplyExpProblem, +function update!( local_state, - region_iterator; + region_iterator::RegionIterator{<:ApplyExpProblem}; nsites, exponent_step, solver=runge_kutta_solver, outputlevel, kws..., ) - iszero(abs(exponent_step)) && return prob, local_state + prob = problem(region_iterator) + + iszero(abs(exponent_step)) && return local_state local_state, info = solver( x -> optimal_map(operator(prob), x), exponent_step, local_state; kws... @@ -52,9 +53,9 @@ function update( end end - prob = set_current_exponent(prob, current_exponent(prob) + exponent_step) + prob.current_exponent += exponent_step - return prob, local_state + return local_state end function default_sweep_callback( diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 0f4784f3..f534ca22 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -26,22 +26,24 @@ function set_truncation_info(E::EigsolveProblem; spectrum=nothing) return E end -function update( - prob::EigsolveProblem, +function update!( local_state, - region_iterator; + region_iterator::RegionIterator{<:EigsolveProblem}; outputlevel, solver=eigsolve_solver, kws..., ) + prob = problem(region_iterator) + eigval, local_state = solver(ψ -> optimal_map(operator(prob), ψ), local_state; kws...) - prob = set_eigenvalue(prob, eigval) + prob.eigenvalue = eigval + if outputlevel >= 2 @printf( " Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob) ) end - return prob, local_state + return local_state end function default_sweep_callback( diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index c5d4234d..844f8a61 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -49,28 +49,21 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}; sweep, kws...) return local_tensor end -# function update( -# region_iter::RegionIterator{FittingProblem}, local_tensor, region; outputlevel, kws... -# ) -# F = problem(region_iter) -# -# region = current_region(F) -# -# n = (local_tensor * dag(local_tensor))[] -# F.overlap = n / sqrt(n) -# if outputlevel >= 2 -# @printf(" Region %s: squared overlap = %.12f\n", region, overlap(F)) -# end -# return F, local_tensor -# end - -function update(F::FittingProblem, local_tensor, region; outputlevel, kws...) +function update!( + local_tensor, region_iter::RegionIterator{<:FittingProblem}; outputlevel, kws... +) + F = problem(region_iter) + + region = current_region(region_iter) + n = (local_tensor * dag(local_tensor))[] - F = set_overlap(F, n / sqrt(n)) + F.overlap = n / sqrt(n) + if outputlevel >= 2 @printf(" Region %s: squared overlap = %.12f\n", region, overlap(F)) end - return F, local_tensor + + return local_tensor end function region_plan(F::FittingProblem; nsites, sweep_kwargs...) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 9e495c04..cef70aa9 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -107,7 +107,7 @@ function region_step( sweep = region_iterator.sweep local_state = extract!(region_iterator; extract_kwargs..., sweep, kws...) - prob, local_state = update(prob, local_state, region_iterator; update_kwargs..., kws...) + local_state = update!(local_state, region_iterator; update_kwargs..., kws...) prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...) return prob end From 0a6e891664944262b88cd5c1652a8607cf587838 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 3 Oct 2025 14:47:31 -0400 Subject: [PATCH 11/55] Make `insert` mutable --- src/solvers/abstract_problem.jl | 2 +- src/solvers/eigsolve.jl | 4 ++-- src/solvers/insert.jl | 15 +++++++++------ src/solvers/iterators.jl | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/solvers/abstract_problem.jl b/src/solvers/abstract_problem.jl index e3bde03b..6e629090 100644 --- a/src/solvers/abstract_problem.jl +++ b/src/solvers/abstract_problem.jl @@ -1,4 +1,4 @@ abstract type AbstractProblem end -set_truncation_info(P::AbstractProblem, args...; kws...) = P +set_truncation_info!(P::AbstractProblem, args...; kws...) = P diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index f534ca22..6bdafa03 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -19,9 +19,9 @@ set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue set_state(E::EigsolveProblem, state) = (@set E.state = state) set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror) -function set_truncation_info(E::EigsolveProblem; spectrum=nothing) +function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) if !isnothing(spectrum) - E = set_max_truncerror(E, max(max_truncerror(E), truncerror(spectrum))) + E.max_truncerror = max(max_truncerror(E), truncerror(spectrum)) end return E end diff --git a/src/solvers/insert.jl b/src/solvers/insert.jl index b3c60645..b71fb05a 100644 --- a/src/solvers/insert.jl +++ b/src/solvers/insert.jl @@ -1,7 +1,6 @@ using NamedGraphs: edgetype -function insert( - problem, +function insert!( local_tensor, region_iterator; normalize=false, @@ -11,9 +10,11 @@ function insert( outputlevel=0, kws..., ) + prob = problem(region_iterator) + trunc = truncation_parameters(sweep; trunc...) region = current_region(region_iterator) - psi = copy(state(problem)) + psi = copy(state(prob)) if length(region) == 1 C = local_tensor elseif length(region) == 2 @@ -22,7 +23,7 @@ function insert( tags = ITensors.tags(psi, e) U, C, spectrum = factorize(local_tensor, indsTe; tags, trunc...) @preserve_graph psi[first(region)] = U - problem = set_truncation_info(problem; spectrum) + prob = set_truncation_info!(prob; spectrum) else error("Region of length $(length(region)) not currently supported") end @@ -30,6 +31,8 @@ function insert( @preserve_graph psi[v] = C psi = set_orthogonal_region ? set_ortho_region(psi, [v]) : psi normalize && @preserve_graph psi[v] = psi[v] / norm(psi[v]) - problem = set_state(problem, psi) - return problem + + prob.state = psi + + return prob end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index cef70aa9..6f4be38b 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -108,7 +108,7 @@ function region_step( local_state = extract!(region_iterator; extract_kwargs..., sweep, kws...) local_state = update!(local_state, region_iterator; update_kwargs..., kws...) - prob = insert(prob, local_state, region_iterator; sweep, insert_kwargs..., kws...) + prob = insert!(local_state, region_iterator; sweep, insert_kwargs..., kws...) return prob end From 0653c475fc53fa7b04df8259b0865c4cc4e2df04 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 3 Oct 2025 16:35:56 -0400 Subject: [PATCH 12/55] First implementation of an `options` system. To be simplified... --- src/ITensorNetworks.jl | 1 + src/solvers/options.jl | 135 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 src/solvers/options.jl diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 339e900e..7900e05a 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -66,6 +66,7 @@ include("solvers/abstract_problem.jl") include("solvers/eigsolve.jl") include("solvers/applyexp.jl") include("solvers/fitting.jl") +include("solvers/options.jl") include("apply.jl") include("inner.jl") diff --git a/src/solvers/options.jl b/src/solvers/options.jl new file mode 100644 index 00000000..bd823810 --- /dev/null +++ b/src/solvers/options.jl @@ -0,0 +1,135 @@ +defaults() = defaults(Any) +defaults(::Any) = (;) + +user_defaults() = user_defaults(Any) +user_defaults(::Any) = (;) + +# The default case is no defaults exposed at all; they are hardcoded as keyword arguments +# in the function. +# +# defaults(::AbstractProblem, ::Function), in reality, but dont type so one can specialize. +defaults(::Any, ::Any) = (;) +user_defaults(::Any, ::Any) = (;) + +""" +Use this function to get options (keyword arguments) from the `RegionIterator` object. +For now we ignore the possibilty of having option packs NOT tied to functions. +""" +function getoption(region_iter::RegionIterator, name=nothing) + # Get the current specific options for the region + opt = current_region_kwargs(region_iter) + prob = problem(region_iter) # We use this to dispatch different defaults + + if isnothing(name) + # If no `name` then just return the "global" defaults overridden by whatever is in `opt` + # as a NamedTuple + return merge(defaults(prob), user_defaults(prob), opt) + elseif name isa Symbol + # If `name isa Symbol`, then this refers to a specific global option, so expand global + # defaults (with overwrites from `opt`) and return this field. + return getfield(getoption(region_iter), name) + elseif name isa Function + # If `name` is a Function, then this refers to a set of options tied to the function + # `name`, we should expand these defaults, override with the `opt.name` and then return + # the NamedTuple that results. + + default_opt = defaults(prob, name) + user_default_opt = user_defaults(prob, name) + region_opt = get(opt, Symbol(name), (;)) + + return merge(default_opt, user_default_opt, region_opt) + end +end + +function expand_defaults(f, region_iter::RegionIterator) + opt = current_region_kwargs(region_iter) + prob = problem(region_iter) + + return merge(default_kwargs(f, prob), get(opt, Symbol(name), (;))) +end + +function getoption(region_iter::RegionIterator, func::Function, name::Symbol) + # Returning a specific option of a the options of `func`. + return getfield(getoption(region_iter, func), name) +end + +#= + +# Example: + +struct MyProblem <: AbstractProblem end + +# We have to set the "global" defaults, (if we want to use any), as there is no notion +# of a function where they can be set. If they are in the region plan then that will be used, +# but without defaults set you would have to always have `verbosity` (say) in the region plan +defaults(::MyProblem) = (; verbosity=0) + +function compute!(iter::RegionIterator{MyProblem}) + # By default, `getoption` will just splat whatever the region plan opts are! + extract!(iter; getoption(iter, extract!)...) + error = update!(iter; getoption(iter, update!)...) + + # This _will_ error if `verbosity` is not defined by `defaults`. + if getoption(iter, :verbosity) > 0 + @info "Error: $error" + end + + return iter +end + +# Now lets customize the `update!` function for our specific type. Let suppose we are +# just quickly prototyping and do not care about sharing code and setting defaults etc, we +# can still just use normal keyword arguments. +# +# The return value of `defaults(problem, update!)` overwrites these hard-coded values, but +# by default `defaults(::AbstractProblem, ::Function) = (;)` so overwrites nothing. +function update!(iter::RegionIterator{MyProblem}; maxiter=100, normalize=true) + total_error = 0 + + for _ in 1:maxiter + state, error = truncation(iter; getoption(iter, truncate)...) + total_error += error + if normalize + state = state / norm(state) + end + end + + return total_error +end + +# e.g. ... +truncation(iter; kwargs...) = rand(2, 2), 1 +extract!(iter; kwargs...) = nothing + +# If you now want to share these defaults, then you should define the following: +function defaults(::MyProblem, ::typeof(update!)) + # These will overwrite the keyword defaults. You may want to remove the keyword defaults + # to remove any ambiguity i.e. `function update!(...; maxiter, norm) ...` + return (; maxiter=200, normalize=true) +end + +# If you want a user to be able to override these defaults with their own defaults (without +# introducing an abstract type) we need another function (this would be set by the user.) +function user_defaults(::MyProblem, ::typeof(update!)) + # This only overwrites the specified default. + return (; normalize=false) +end + +# So, in order of priority, the options get chosen like +# - whatever the options from the region plan are +# - whatever is in `user_defaults` +# - whatever is in `defaults` +# - whatever the keyword argument is set to (if anything). +# +# The `NamedTuple`s in the region plan only need to have one layer of nesting, i.e. the +# "global options" (if any) and the function option packs. + +function test() + ri = RegionIterator( + MyProblem(), ["region" => ((update!)=(; maxiter=300), verbosity=1)], 1 + ) + compute!(ri) + return nothing +end + +=# From d77321e526121726f65becd6ef2afc975b2fdf96 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 6 Oct 2025 10:14:59 -0400 Subject: [PATCH 13/55] Simplify options interface to a single function `default_kwargs`. --- src/solvers/options.jl | 136 ++--------------------------------------- 1 file changed, 6 insertions(+), 130 deletions(-) diff --git a/src/solvers/options.jl b/src/solvers/options.jl index bd823810..31825731 100644 --- a/src/solvers/options.jl +++ b/src/solvers/options.jl @@ -1,135 +1,11 @@ -defaults() = defaults(Any) -defaults(::Any) = (;) - -user_defaults() = user_defaults(Any) -user_defaults(::Any) = (;) - -# The default case is no defaults exposed at all; they are hardcoded as keyword arguments -# in the function. -# -# defaults(::AbstractProblem, ::Function), in reality, but dont type so one can specialize. defaults(::Any, ::Any) = (;) -user_defaults(::Any, ::Any) = (;) """ -Use this function to get options (keyword arguments) from the `RegionIterator` object. -For now we ignore the possibilty of having option packs NOT tied to functions. -""" -function getoption(region_iter::RegionIterator, name=nothing) - # Get the current specific options for the region - opt = current_region_kwargs(region_iter) - prob = problem(region_iter) # We use this to dispatch different defaults - - if isnothing(name) - # If no `name` then just return the "global" defaults overridden by whatever is in `opt` - # as a NamedTuple - return merge(defaults(prob), user_defaults(prob), opt) - elseif name isa Symbol - # If `name isa Symbol`, then this refers to a specific global option, so expand global - # defaults (with overwrites from `opt`) and return this field. - return getfield(getoption(region_iter), name) - elseif name isa Function - # If `name` is a Function, then this refers to a set of options tied to the function - # `name`, we should expand these defaults, override with the `opt.name` and then return - # the NamedTuple that results. - - default_opt = defaults(prob, name) - user_default_opt = user_defaults(prob, name) - region_opt = get(opt, Symbol(name), (;)) - - return merge(default_opt, user_default_opt, region_opt) - end -end - -function expand_defaults(f, region_iter::RegionIterator) - opt = current_region_kwargs(region_iter) - prob = problem(region_iter) - - return merge(default_kwargs(f, prob), get(opt, Symbol(name), (;))) -end - -function getoption(region_iter::RegionIterator, func::Function, name::Symbol) - # Returning a specific option of a the options of `func`. - return getfield(getoption(region_iter, func), name) -end - -#= - -# Example: - -struct MyProblem <: AbstractProblem end - -# We have to set the "global" defaults, (if we want to use any), as there is no notion -# of a function where they can be set. If they are in the region plan then that will be used, -# but without defaults set you would have to always have `verbosity` (say) in the region plan -defaults(::MyProblem) = (; verbosity=0) + function default_kwargs(f, iter::RegionIterator) -function compute!(iter::RegionIterator{MyProblem}) - # By default, `getoption` will just splat whatever the region plan opts are! - extract!(iter; getoption(iter, extract!)...) - error = update!(iter; getoption(iter, update!)...) - - # This _will_ error if `verbosity` is not defined by `defaults`. - if getoption(iter, :verbosity) > 0 - @info "Error: $error" - end - - return iter -end - -# Now lets customize the `update!` function for our specific type. Let suppose we are -# just quickly prototyping and do not care about sharing code and setting defaults etc, we -# can still just use normal keyword arguments. -# -# The return value of `defaults(problem, update!)` overwrites these hard-coded values, but -# by default `defaults(::AbstractProblem, ::Function) = (;)` so overwrites nothing. -function update!(iter::RegionIterator{MyProblem}; maxiter=100, normalize=true) - total_error = 0 - - for _ in 1:maxiter - state, error = truncation(iter; getoption(iter, truncate)...) - total_error += error - if normalize - state = state / norm(state) - end - end - - return total_error -end - -# e.g. ... -truncation(iter; kwargs...) = rand(2, 2), 1 -extract!(iter; kwargs...) = nothing - -# If you now want to share these defaults, then you should define the following: -function defaults(::MyProblem, ::typeof(update!)) - # These will overwrite the keyword defaults. You may want to remove the keyword defaults - # to remove any ambiguity i.e. `function update!(...; maxiter, norm) ...` - return (; maxiter=200, normalize=true) -end - -# If you want a user to be able to override these defaults with their own defaults (without -# introducing an abstract type) we need another function (this would be set by the user.) -function user_defaults(::MyProblem, ::typeof(update!)) - # This only overwrites the specified default. - return (; normalize=false) -end - -# So, in order of priority, the options get chosen like -# - whatever the options from the region plan are -# - whatever is in `user_defaults` -# - whatever is in `defaults` -# - whatever the keyword argument is set to (if anything). -# -# The `NamedTuple`s in the region plan only need to have one layer of nesting, i.e. the -# "global options" (if any) and the function option packs. - -function test() - ri = RegionIterator( - MyProblem(), ["region" => ((update!)=(; maxiter=300), verbosity=1)], 1 - ) - compute!(ri) - return nothing +Return the default keyword arguments for the function `f` overridden by the contents of `iter`. +""" +function default_kwargs(f::Function, iter::RegionIterator) + region_kwargs = get(current_region_kwargs(iter), Symbol(f, :_kwargs), (;)) + return merge(default_kwargs(f, problem(iter)), region_kwargs) end - -=# From aff14c7ff595e701c267f1d7c30ca72c08222b2e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 6 Oct 2025 10:17:48 -0400 Subject: [PATCH 14/55] Put calls to `extract!` etc in `compute!` function directly This removes the function `region_step`. Default kwargs for these functions (none currently) are now splatted in using the `default_kwargs` function. Remove `sweep` kwargs as this can be obtained from the `RegionIterator`. --- src/solvers/extract.jl | 6 ++++-- src/solvers/iterators.jl | 27 ++++++++------------------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index 66c5d19c..493b3177 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -1,6 +1,8 @@ -function extract!(region_iterator; sweep, trunc=(;), kws...) +function extract!(region_iterator; trunc=(;)) prob = problem(region_iterator) + sweep = region_iterator.sweep + trunc = truncation_parameters(sweep; trunc...) region = current_region(region_iterator) psi = orthogonalize(state(prob), region) @@ -8,7 +10,7 @@ function extract!(region_iterator; sweep, trunc=(;), kws...) prob.state = psi - local_state = subspace_expand!(local_state, region_iterator; sweep, trunc, kws...) + local_state = subspace_expand!(local_state, region_iterator; sweep, trunc) shifted_operator = position(operator(prob), state(prob), region) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 6f4be38b..b5aa7938 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -83,35 +83,24 @@ is_last_region(region_iter::RegionIterator) = length(region_iter) === state(regi # # Functions associated with RegionIterator # - -function compute!(region_iter::RegionIterator) - region_kwargs = current_region_kwargs(region_iter) - region_iter.problem = region_step(region_iter; region_kwargs...) - return region_iter -end function increment!(region_iter::RegionIterator) region_iter.which_region += 1 return region_iter end +function compute!(iter::RegionIterator) + local_state = extract!(iter; default_kwargs(extract!, iter)...) + local_state = update!(local_state, iter; default_kwargs(update!, iter)...) + insert!(local_state, iter; default_kwargs(insert!, iter)...) + + return iter +end + function RegionIterator(problem; sweep, sweep_kwargs...) plan = region_plan(problem; sweep, sweep_kwargs...) return RegionIterator(problem, plan, sweep) end -function region_step( - region_iterator; extract_kwargs=(;), update_kwargs=(;), insert_kwargs=(;), kws... -) - prob = problem(region_iterator) - - sweep = region_iterator.sweep - - local_state = extract!(region_iterator; extract_kwargs..., sweep, kws...) - local_state = update!(local_state, region_iterator; update_kwargs..., kws...) - prob = insert!(local_state, region_iterator; sweep, insert_kwargs..., kws...) - return prob -end - function region_plan(problem; kws...) return euler_sweep(state(problem); kws...) end From 4b21cc948fb5c08879d57f61ade5caf937213255 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 13:58:18 -0400 Subject: [PATCH 15/55] Refactor the region plan generating code. Includes the code change `(region, kwargs)` to `region => kwargs` for readability, but I also think the `Pair` data structure is more appropriate here. Reversing the region now happens in seperate function. --- src/solvers/region_plans/dfs_plans.jl | 6 +- src/solvers/region_plans/euler_plans.jl | 4 +- src/solvers/region_plans/tdvp_region_plans.jl | 61 ++++++++++++------- 3 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/solvers/region_plans/dfs_plans.jl b/src/solvers/region_plans/dfs_plans.jl index 074fa94a..9b44b980 100644 --- a/src/solvers/region_plans/dfs_plans.jl +++ b/src/solvers/region_plans/dfs_plans.jl @@ -3,14 +3,14 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, post_order_dfs_edges, post_order_dfs_vertices function post_order_dfs_plan( - graph; nsites, root_vertex=default_root_vertex(graph), sweep_kwargs... + graph, sweep_kwargs; nsites, root_vertex=default_root_vertex(graph) ) if nsites == 1 vertices = post_order_dfs_vertices(graph, root_vertex) - fwd_sweep = [([v], sweep_kwargs) for v in vertices] + fwd_sweep = [[v] => sweep_kwargs for v in vertices] elseif nsites == 2 edges = post_order_dfs_edges(graph, root_vertex) - fwd_sweep = [([src(e), dst(e)], sweep_kwargs) for e in edges] + fwd_sweep = [[src(e), dst(e)] => sweep_kwargs for e in edges] end return fwd_sweep end diff --git a/src/solvers/region_plans/euler_plans.jl b/src/solvers/region_plans/euler_plans.jl index cf661d0d..6c5304f9 100644 --- a/src/solvers/region_plans/euler_plans.jl +++ b/src/solvers/region_plans/euler_plans.jl @@ -4,10 +4,10 @@ using NamedGraphs.GraphsExtensions: default_root_vertex function euler_sweep(graph; nsites, root_vertex=default_root_vertex(graph), sweep_kwargs...) if nsites == 1 vertices = euler_tour_vertices(graph, root_vertex) - sweep = [([v], sweep_kwargs) for v in vertices] + sweep = [[v] => sweep_kwargs for v in vertices] elseif nsites == 2 edges = euler_tour_edges(graph, root_vertex) - sweep = [([src(e), dst(e)], sweep_kwargs) for e in edges] + sweep = [[src(e), dst(e)] => sweep_kwargs for e in edges] end return sweep end diff --git a/src/solvers/region_plans/tdvp_region_plans.jl b/src/solvers/region_plans/tdvp_region_plans.jl index c03ad4eb..7b24211c 100644 --- a/src/solvers/region_plans/tdvp_region_plans.jl +++ b/src/solvers/region_plans/tdvp_region_plans.jl @@ -1,3 +1,5 @@ +using Accessors: @modify + function applyexp_sub_steps(order) if order == 1 return [1.0] @@ -5,40 +7,57 @@ function applyexp_sub_steps(order) return [1 / 2, 1 / 2] elseif order == 4 s = (2 - 2^(1 / 3))^(-1) - return [s/2, s/2, 1/2 - s, 1/2 - s, s/2, s/2] + return [s / 2, s / 2, 1 / 2 - s, 1 / 2 - s, s / 2, s / 2] else error("Applyexp order of $order not supported") end end -function first_order_sweep( - graph, exponent_step, dir=Base.Forward; update_kwargs, nsites, kws... -) - basic_fwd_sweep = post_order_dfs_plan(graph; nsites, kws...) - update_kwargs = (; nsites, exponent_step, update_kwargs...) - sweep = [] - for (j, (region, region_kws)) in enumerate(basic_fwd_sweep) - push!(sweep, (region, (; nsites, update_kwargs, region_kws...))) +function first_order_sweep(graph, sweep_kwargs; nsites) + basic_fwd_sweep = post_order_dfs_plan(graph, sweep_kwargs; nsites) + region_plan = [] + + for (j, (region, region_kwargs)) in enumerate(basic_fwd_sweep) + push!(region_plan, region => region_kwargs) + if length(region) == 2 && j < length(basic_fwd_sweep) - rev_kwargs = (; update_kwargs..., exponent_step=(-update_kwargs.exponent_step)) - push!(sweep, ([last(region)], (; update_kwargs=rev_kwargs, region_kws...))) + region_kwargs = @modify(-, region_kwargs.update!_kwargs.exponent_step) + push!(region_plan, [last(region)] => region_kwargs) end end - if dir==Base.Reverse - # Reverse regions as well as ordering of regions - sweep = [(reverse(reg_kws[1]), reg_kws[2]) for reg_kws in reverse(sweep)] + + return region_plan +end + +function reverse_regions(region_plan) + region_plan = map(reverse(region_plan)) do region_kwargs + region, kwargs = region_kwargs + return reverse(region) => kwargs end - return sweep + + return region_plan end -function applyexp_regions(graph, exponent_step; update_kwargs, order, nsites, kws...) +# Generate the kwargs for each region. +function applyexp_regions( + graph, raw_exponent_step; order, nsites, update!_kwargs=(; nsites), remaining_kwargs... +) sweep_plan = [] + for (step, weight) in enumerate(applyexp_sub_steps(order)) - dir = isodd(step) ? Base.Forward : Base.Reverse - append!( - sweep_plan, - first_order_sweep(graph, weight*exponent_step, dir; update_kwargs, nsites, kws...), - ) + # Use this exponent step only if none provided + new_update!_kwargs = (; exponent_step=weight * raw_exponent_step, update!_kwargs...) + + sweep_kwargs = (; remaining_kwargs..., update!_kwargs=new_update!_kwargs) + + region_plan = first_order_sweep(graph, sweep_kwargs; nsites) + + if iseven(step) + region_plan = reverse_regions(region_plan) + end + + append!(sweep_plan, region_plan) end + return sweep_plan end From e71512fb1e1130ba51b7db0a6afe281f7300bea9 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 13:59:05 -0400 Subject: [PATCH 16/55] Have `dmrg` take a strict number of arguments --- src/solvers/eigsolve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 6bdafa03..f12b1a3c 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -86,4 +86,4 @@ function eigsolve( return eigenvalue(prob), state(prob) end -dmrg(args...; kws...) = eigsolve(args...; kws...) +dmrg(operator, init_state; kwargs...) = eigsolve(operator, init_state; kwargs...) From a4ce308be123abdde960b9f2f8368d1579c8500e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 14:11:06 -0400 Subject: [PATCH 17/55] Purge non-mutating field setter functions. These are no longer necessary. --- src/solvers/applyexp.jl | 5 ----- src/solvers/eigsolve.jl | 6 ------ src/solvers/fitting.jl | 5 ----- 3 files changed, 16 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index c0f60bbe..367d61b7 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -1,5 +1,4 @@ using Printf: @printf -using Accessors: @set @kwdef mutable struct ApplyExpProblem{State} <: AbstractProblem operator @@ -15,10 +14,6 @@ function current_time(A::ApplyExpProblem) return iszero(imag(t)) ? real(t) : t end -set_operator(A::ApplyExpProblem, operator) = (@set A.operator = operator) -set_state(A::ApplyExpProblem, state) = (@set A.state = state) -set_current_exponent(A::ApplyExpProblem, exponent) = (@set A.current_exponent = exponent) - function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...) return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index f12b1a3c..0b253a92 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -1,4 +1,3 @@ -using Accessors: @set using Printf: @printf using ITensors: truncerror @@ -14,11 +13,6 @@ state(E::EigsolveProblem) = E.state operator(E::EigsolveProblem) = E.operator max_truncerror(E::EigsolveProblem) = E.max_truncerror -set_operator(E::EigsolveProblem, operator) = (@set E.operator = operator) -set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue) -set_state(E::EigsolveProblem, state) = (@set E.state = state) -set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror) - function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) if !isnothing(spectrum) E.max_truncerror = max(max_truncerror(E), truncerror(spectrum)) diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 844f8a61..667afb7d 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -1,9 +1,7 @@ -using Accessors: @set using Graphs: vertices using NamedGraphs: AbstractNamedGraph, NamedEdge using NamedGraphs.PartitionedGraphs: partitionedges using Printf: @printf -using ConstructionBase: setproperties @kwdef mutable struct FittingProblem{State<:AbstractBeliefPropagationCache} <: AbstractProblem @@ -18,9 +16,6 @@ ket_graph(F::FittingProblem) = F.ket_graph overlap(F::FittingProblem) = F.overlap gauge_region(F::FittingProblem) = F.gauge_region -set_state(F::FittingProblem, state) = (@set F.state = state) -set_overlap(F::FittingProblem, overlap) = (@set F.overlap = overlap) - function ket(F::FittingProblem) ket_vertices = vertices(ket_graph(F)) return first(induced_subgraph(tensornetwork(state(F)), ket_vertices)) From a8b2c51364b03055a93f5f1313a06d773ddcc027 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 15:14:18 -0400 Subject: [PATCH 18/55] Use `current_kwargs` for getting kwargs from `RegionIterator` --- src/solvers/iterators.jl | 6 +++--- src/solvers/options.jl | 46 +++++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index b5aa7938..77b86808 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -89,9 +89,9 @@ function increment!(region_iter::RegionIterator) end function compute!(iter::RegionIterator) - local_state = extract!(iter; default_kwargs(extract!, iter)...) - local_state = update!(local_state, iter; default_kwargs(update!, iter)...) - insert!(local_state, iter; default_kwargs(insert!, iter)...) + local_state = extract!(iter; current_kwargs(extract!, iter)...) + local_state = update!(local_state, iter; current_kwargs(update!, iter)...) + insert!(local_state, iter; current_kwargs(insert!, iter)...) return iter end diff --git a/src/solvers/options.jl b/src/solvers/options.jl index 31825731..528e8f0b 100644 --- a/src/solvers/options.jl +++ b/src/solvers/options.jl @@ -1,11 +1,47 @@ -defaults(::Any, ::Any) = (;) +""" + default_kwargs(f, [obj = Any]) + +Return the default keyword arguments for the function `f`. These defaults may be +derived from the contents or type of the second arugment `obj`. + +## Interface +Given a function `f`, one can optionally set the default keyword arguments for this +function by specializing either of the following two-argument methods: +``` +ITensorNetworks.default_kwargs(::typeof(f), prob::AbstractProblem) +ITensorNetworks.default_kwargs(::typeof(f), ::Type{<:AbstractProblem}) +``` +If one does not require the contents of `prob::Prob` to generate the defaults then it is +recommended to dispatch on `Type{<:Prob}` directly (second method) so the defaults +can be accessed without constructing an instance of a `Prob`. + +The return value of `default_kwargs` should be a `NamedTuple`, and will overwrite any +default values set in the function signature. """ - function default_kwargs(f, iter::RegionIterator) +default_kwargs(f) = default_kwargs(f, Any) +default_kwargs(f, obj) = _default_kwargs_fallback(f, obj) + +# To avoid annoying potential method ambiguities. +function _default_kwargs_fallback(f, iter::RegionIterator) + return default_kwargs(f, problem(iter)) +end +function _default_kwargs_fallback(f, problem::AbstractProblem) + return default_kwargs(f, typeof(problem)) +end + +# Eventually we reach this if nothing is specialized. +_default_kwargs_fallback(::Any, ::DataType) = (;) -Return the default keyword arguments for the function `f` overridden by the contents of `iter`. """ -function default_kwargs(f::Function, iter::RegionIterator) + current_kwargs(f, iter::RegionIterator) + +Return the keyword arguments to be passed to the function `f` for the current region +defined by the stateful iterator `iter`. +""" +function current_kwargs(f::Function, iter::RegionIterator) region_kwargs = get(current_region_kwargs(iter), Symbol(f, :_kwargs), (;)) - return merge(default_kwargs(f, problem(iter)), region_kwargs) + rv = merge(default_kwargs(f, iter), region_kwargs) + return rv +end end From 18a85037992108105865dc189c32c0e869e3b5e6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 15:16:12 -0400 Subject: [PATCH 19/55] Introduce defaults using `default_kwargs` and be stricter about which kwargs get passed to functions --- src/solvers/applyexp.jl | 33 ++--- src/solvers/eigsolve.jl | 35 ++--- src/solvers/extract.jl | 16 ++- src/solvers/fitting.jl | 34 ++--- src/solvers/insert.jl | 28 ++-- src/solvers/iterators.jl | 4 +- src/solvers/options.jl | 6 + src/solvers/subspace/densitymatrix.jl | 21 ++- src/solvers/subspace/subspace.jl | 59 ++++---- test/solvers/test_applyexp.jl | 189 +++++++++++++------------- test/solvers/test_eigsolve.jl | 13 +- 11 files changed, 211 insertions(+), 227 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 367d61b7..1cba9f44 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -14,8 +14,10 @@ function current_time(A::ApplyExpProblem) return iszero(imag(t)) ? real(t) : t end -function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...) - return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...) +# Rename region_plan +function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...) + # The `exponent_step` kwarg for the `update!` function needs some pre-processing. + return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...) end function update!( @@ -24,14 +26,13 @@ function update!( nsites, exponent_step, solver=runge_kutta_solver, - outputlevel, kws..., ) prob = problem(region_iterator) iszero(abs(exponent_step)) && return local_state - local_state, info = solver( + local_state, _ = solver( x -> optimal_map(operator(prob), x), exponent_step, local_state; kws... ) if nsites == 1 @@ -76,20 +77,20 @@ end function applyexp( init_prob::AbstractProblem, exponents; - extract_kwargs=(;), - update_kwargs=(;), - insert_kwargs=(;), - outputlevel=0, - nsites=1, + sweep_callback=default_sweep_callback, order=4, - kws..., + nsites=2, + sweep_kwargs..., ) exponent_steps = diff([zero(eltype(exponents)); exponents]) - # exponent_steps = diff(exponents) - sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs) - kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps] + + kws_array = [ + (; order, nsites, sweep_kwargs..., exponent_step) for exponent_step in exponent_steps + ] sweep_iter = SweepIterator(init_prob, kws_array) - converged_prob = sweep_solve(sweep_iter; outputlevel, kws...) + + converged_prob = sweep_solve(sweep_callback, sweep_iter; outputlevel=0) + return state(converged_prob) end @@ -111,8 +112,8 @@ function time_evolve( process_time=process_real_times, sweep_callback=(a...; k...) -> default_sweep_callback(a...; exponent_description="time", process_time, k...), - kws..., + sweep_kwargs..., ) exponents = [-im * t for t in time_points] - return applyexp(operator, exponents, init_state; sweep_callback, kws...) + return applyexp(operator, exponents, init_state; sweep_callback, sweep_kwargs...) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 0b253a92..2326ca04 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -21,15 +21,16 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) end function update!( - local_state, - region_iterator::RegionIterator{<:EigsolveProblem}; - outputlevel, - solver=eigsolve_solver, - kws..., + local_state, region_iterator::RegionIterator{<:EigsolveProblem}; outputlevel, solver ) prob = problem(region_iterator) - eigval, local_state = solver(ψ -> optimal_map(operator(prob), ψ), local_state; kws...) + eigval, local_state = solver( + ψ -> optimal_map(operator(prob), ψ), + local_state; + current_kwargs(solver, region_iterator)..., + ) + prob.eigenvalue = eigval if outputlevel >= 2 @@ -40,6 +41,10 @@ function update!( return local_state end +function default_kwargs(::typeof(update!), ::Type{<:EigsolveProblem}) + return (; outputlevel=0, solver=eigsolve_solver) +end + function default_sweep_callback( sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel ) @@ -59,24 +64,12 @@ function default_sweep_callback( end end -function eigsolve( - operator, - init_state; - nsweeps, - nsites=1, - outputlevel=0, - extract_kwargs=(;), - update_kwargs=(;), - insert_kwargs=(;), - kws..., -) +function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_kwargs...) init_prob = EigsolveProblem(; state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) - sweep_iter = SweepIterator( - init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs - ) - prob = sweep_solve(sweep_iter; outputlevel, kws...) + sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...) + prob = sweep_solve(sweep_iter; outputlevel) return eigenvalue(prob), state(prob) end diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index 493b3177..d87d3f88 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -1,17 +1,19 @@ -function extract!(region_iterator; trunc=(;)) - prob = problem(region_iterator) +function extract!(iter; kwargs...) + return _extract_fallback!(iter; subspace_algorithm="nothing", kwargs...) +end - sweep = region_iterator.sweep +# Internal function such that a method error can be thrown while still allowing a user +# to specialize on `extract!` +function _extract_fallback!(region_iter::RegionIterator; subspace_algorithm) + prob = problem(region_iter) + region = current_region(region_iter) - trunc = truncation_parameters(sweep; trunc...) - region = current_region(region_iterator) psi = orthogonalize(state(prob), region) local_state = prod(psi[v] for v in region) prob.state = psi - local_state = subspace_expand!(local_state, region_iterator; sweep, trunc) - + local_state = subspace_expand!(local_state, region_iter; subspace_algorithm) shifted_operator = position(operator(prob), state(prob), region) prob.operator = shifted_operator diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 667afb7d..0ec7e536 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -21,7 +21,7 @@ function ket(F::FittingProblem) return first(induced_subgraph(tensornetwork(state(F)), ket_vertices)) end -function extract!(region_iter::RegionIterator{<:FittingProblem}; sweep, kws...) +function extract!(region_iter::RegionIterator{<:FittingProblem}) prob = problem(region_iter) region = current_region(region_iter) @@ -44,9 +44,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}; sweep, kws...) return local_tensor end -function update!( - local_tensor, region_iter::RegionIterator{<:FittingProblem}; outputlevel, kws... -) +function update!(local_tensor, region_iter::RegionIterator{<:FittingProblem}; outputlevel) F = problem(region_iter) region = current_region(region_iter) @@ -71,11 +69,10 @@ function fit_tensornetwork( nsweeps=25, nsites=1, outputlevel=0, - extract_kwargs=(;), - update_kwargs=(;), - insert_kwargs=(;), normalize=true, - kws..., + maxdim=default_kwargs(factorize).maxdim, + cutoff=default_kwargs(factorize).cutoff, + extra_sweep_kwargs..., ) bpc = BeliefPropagationCache(overlap_network, args...) ket_graph = first( @@ -85,11 +82,16 @@ function fit_tensornetwork( ket_graph, state=bpc, gauge_region=collect(vertices(ket_graph)) ) - insert_kwargs = (; insert_kwargs..., normalize, set_orthogonal_region=false) - common_sweep_kwargs = (; nsites, outputlevel, update_kwargs, insert_kwargs) - kwargs_array = [(; common_sweep_kwargs..., sweep=s) for s in 1:nsweeps] + insert!_kwargs = (; normalize, set_orthogonal_region=false) + update!_kwargs = (; outputlevel) + factorize_kwargs = (; maxdim, cutoff) + + sweep_kwargs = (; nsites, outputlevel, update!_kwargs, insert!_kwargs, factorize_kwargs) + kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps] + sweep_iter = SweepIterator(init_prob, kwargs_array) - converged_prob = sweep_solve(sweep_iter; outputlevel, kws...) + converged_prob = sweep_solve(sweep_iter) + return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob)) end @@ -109,12 +111,10 @@ end function ITensors.apply( A::ITensorNetwork, x::ITensorNetwork; - maxdim=default_maxdim(), - cutoff=default_cutoff(), - kwargs..., + maxdim=default_kwargs(factorize).maxdim, + sweep_kwargs..., ) init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim) overlap_network = inner_network(x, A, init_state) - insert_kwargs = (; trunc=(; cutoff, maxdim)) - return fit_tensornetwork(overlap_network; insert_kwargs, kwargs...) + return fit_tensornetwork(overlap_network; maxdim, sweep_kwargs...) end diff --git a/src/solvers/insert.jl b/src/solvers/insert.jl index b71fb05a..4650f524 100644 --- a/src/solvers/insert.jl +++ b/src/solvers/insert.jl @@ -1,19 +1,15 @@ using NamedGraphs: edgetype -function insert!( - local_tensor, - region_iterator; - normalize=false, - set_orthogonal_region=true, - sweep, - trunc=(;), - outputlevel=0, - kws..., -) - prob = problem(region_iterator) +function insert!(local_tensor, region_iter; kwargs...) + return _insert_fallback!( + local_tensor, region_iter; normalize=false, set_orthogonal_region=true, kwargs... + ) +end + +function _insert_fallback!(local_tensor, region_iter; normalize, set_orthogonal_region) + prob = problem(region_iter) - trunc = truncation_parameters(sweep; trunc...) - region = current_region(region_iterator) + region = current_region(region_iter) psi = copy(state(prob)) if length(region) == 1 C = local_tensor @@ -21,7 +17,11 @@ function insert!( e = edgetype(psi)(first(region), last(region)) indsTe = inds(psi[first(region)]) tags = ITensors.tags(psi, e) - U, C, spectrum = factorize(local_tensor, indsTe; tags, trunc...) + + U, C, spectrum = factorize( + local_tensor, indsTe; tags, current_kwargs(factorize, region_iter)... + ) + @preserve_graph psi[first(region)] = U prob = set_truncation_info!(prob; spectrum) else diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 77b86808..97b0a62a 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -101,9 +101,7 @@ function RegionIterator(problem; sweep, sweep_kwargs...) return RegionIterator(problem, plan, sweep) end -function region_plan(problem; kws...) - return euler_sweep(state(problem); kws...) -end +region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) # # SweepIterator diff --git a/src/solvers/options.jl b/src/solvers/options.jl index 528e8f0b..e33723f8 100644 --- a/src/solvers/options.jl +++ b/src/solvers/options.jl @@ -44,4 +44,10 @@ function current_kwargs(f::Function, iter::RegionIterator) rv = merge(default_kwargs(f, iter), region_kwargs) return rv end + +# Generic + +# I think these should be set independent of a function, but for now: +function default_kwargs(::typeof(factorize), ::Any) + return (; maxdim=typemax(Int), cutoff=0.0, mindim=1) end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 5fc77a16..c62d22aa 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -1,19 +1,15 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf +function default_kwargs(::typeof(subspace_expand!), ::Backend"densitymatrix", ::Any) + return (; north_pass=1) +end function subspace_expand!( - ::Backend"densitymatrix", - local_state::ITensor, - region_iterator; - expansion_factor, - max_expand, - north_pass=1, - trunc, - kws..., + ::Backend"densitymatrix", local_state::ITensor, region_iter; north_pass ) - prob = problem(region_iterator) + prob = problem(region_iter) - region = current_region(region_iterator) + region = current_region(region_iter) psi = copy(state(prob)) prev_vertex_set = setdiff(pos(operator(prob)), region) @@ -31,10 +27,9 @@ function subspace_expand!( basis_size = prod(dim.(uniqueinds(A, C))) expanded_maxdim = compute_expansion( - dim(a), basis_size; expansion_factor, max_expand, trunc.maxdim + dim(a), basis_size; current_kwargs(compute_expansion, region_iter)... ) expanded_maxdim <= 0 && return local_state - trunc = (; trunc..., maxdim=expanded_maxdim) envs = environments(operator(prob)) H = operator(operator(prob)) @@ -50,7 +45,7 @@ function subspace_expand!( sqrt_rho = conj_proj_A(sqrt_rho) end rho = sqrt_rho * dag(noprime(sqrt_rho)) - D, U = eigen(rho; trunc..., ishermitian=true) + D, U = eigen(rho; current_kwargs(eigen, region_iter)..., ishermitian=true) Uproj(T) = (T - prime(A, a) * (dag(prime(A, a)) * T)) for pass in 1:north_pass diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index f8549af1..529a844d 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -1,55 +1,38 @@ using NDTensors: NDTensors using NDTensors.BackendSelection: Backend, @Backend_str -default_expansion_factor() = 1.5 -default_max_expand() = typemax(Int) - -function subspace_expand!( - local_state, - region_iterator; - expansion_factor=default_expansion_factor(), - max_expand=default_max_expand(), - subspace_algorithm=nothing, - sweep, - trunc, - kws..., -) - expansion_factor = get_or_last(expansion_factor, sweep) - max_expand = get_or_last(max_expand, sweep) +function subspace_expand!(local_state, region_iter; subspace_algorithm) + backend = Backend(subspace_algorithm) + + if backend isa Backend"nothing" + return local_state + end + local_state = subspace_expand!( - Backend(subspace_algorithm), - local_state, - region_iterator; - expansion_factor, - max_expand, - trunc, - kws..., + backend, local_state, region_iter; current_kwargs(subspace_expand!, region_iter)... ) return local_state end -function subspace_expand!(backend, local_state, region_iterator; kws...) +function default_kwargs(::typeof(subspace_expand!), iter::RegionIterator) + backend = current_kwargs(extract!, iter).subspace_algorithm + return default_kwargs(subspace_expand!, Backend(backend), problem(iter)) +end +default_kwargs(::typeof(subspace_expand!), ::Backend, ::Any) = (;) + +function subspace_expand!(backend, local_state, region_iterator; kwargs...) + # We allow passing of any kwargs here is this method throws an error anyway return error( "Subspace expansion (subspace_expand!) not defined for requested combination of subspace_algorithm and problem types", ) end -function subspace_expand!(backend::Backend{:nothing}, local_state, region_iterator; kws...) - return local_state -end - -function compute_expansion( - current_dim, - basis_size; - expansion_factor=default_expansion_factor(), - max_expand=default_max_expand(), - maxdim=default_maxdim(), -) +function compute_expansion(current_dim, basis_size; expansion_factor, maxexpand, maxdim) # Note: expand_maxdim will be *added* to current bond dimension # Obtain expand_maxdim from expansion_factor expand_maxdim = ceil(Int, expansion_factor * current_dim) # Enforce max_expand keyword - expand_maxdim = min(max_expand, expand_maxdim) + expand_maxdim = min(maxexpand, expand_maxdim) # Restrict expand_maxdim below theoretical upper limit expand_maxdim = min(basis_size - current_dim, expand_maxdim) @@ -57,5 +40,11 @@ function compute_expansion( expand_maxdim = min(maxdim - current_dim, expand_maxdim) # Ensure expand_maxdim is non-negative expand_maxdim = max(0, expand_maxdim) + return expand_maxdim end +function default_kwargs(::typeof(compute_expansion), iter::RegionIterator) + # Derived default + maxdim = current_kwargs(factorize, iter).maxdim + return (; maxexpand=typemax(Int), expansion_factor=1.5, maxdim) +end diff --git a/test/solvers/test_applyexp.jl b/test/solvers/test_applyexp.jl index c45a4817..c3464c8d 100644 --- a/test/solvers/test_applyexp.jl +++ b/test/solvers/test_applyexp.jl @@ -21,101 +21,102 @@ function chain_plus_ancilla(; nchain) return g end -@testset "Test Tree Time Evolution" begin - outputlevel = 0 - - N = 10 - g = chain_plus_ancilla(; nchain=N) - - sites = siteinds("S=1/2", g) - - # Make Heisenberg model Hamiltonian - h = OpSum() - for j in 1:(N-1) - h += "Sz", j, "Sz", j + 1 - h += 1 / 2, "S+", j, "S-", j + 1 - h += 1 / 2, "S-", j, "S+", j + 1 - end - H = ttn(h, sites) - - # Make initial product state - state = Dict{Int,String}() - for (j, v) in enumerate(vertices(sites)) - state[v] = iseven(j) ? "Up" : "Dn" +@testset "Time Evolution" begin + + @testset "Test Tree Time Evolution" begin + outputlevel = 0 + + N = 10 + g = chain_plus_ancilla(; nchain=N) + + sites = siteinds("S=1/2", g) + + # Make Heisenberg model Hamiltonian + h = OpSum() + for j in 1:(N-1) + h += "Sz", j, "Sz", j + 1 + h += 1 / 2, "S+", j, "S-", j + 1 + h += 1 / 2, "S-", j, "S+", j + 1 + end + H = ttn(h, sites) + + # Make initial product state + state = Dict{Int,String}() + for (j, v) in enumerate(vertices(sites)) + state[v] = iseven(j) ? "Up" : "Dn" + end + psi0 = ttn(state, sites) + + cutoff = 1E-10 + maxdim = 100 + nsweeps = 5 + + nsites = 2 + factorize_kwargs = (; cutoff, maxdim) + E, gs_psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel) + (outputlevel >= 1) && println("2-site DMRG energy = ", E) + + nsites = 1 + tmax = 0.10 + time_range = 0.0:0.02:tmax + psi1_t = time_evolve(H, time_range, gs_psi; factorize_kwargs, nsites, outputlevel) + (outputlevel >= 1) && println("Done with $nsites-site TDVP") + + @test norm(psi1_t) > 0.999 + + nsites = 2 + psi2_t = time_evolve(H, time_range, gs_psi; factorize_kwargs, nsites, outputlevel) + (outputlevel >= 1) && println("Done with $nsites-site TDVP") + @test norm(psi2_t) > 0.999 + + @test abs(inner(psi1_t, gs_psi)) > 0.99 + @test abs(inner(psi1_t, psi2_t)) > 0.99 + + # Test that accumulated phase angle is E*tmax + z = inner(psi1_t, gs_psi) + @test atan(imag(z) / real(z)) ≈ E * tmax atol = 1E-4 end - psi0 = ttn(state, sites) - - cutoff = 1E-10 - maxdim = 100 - nsweeps = 5 - - nsites = 2 - trunc = (; cutoff, maxdim) - E, gs_psi = dmrg(H, psi0; insert_kwargs=(; trunc), nsites, nsweeps, outputlevel) - (outputlevel >= 1) && println("2-site DMRG energy = ", E) - - insert_kwargs = (; trunc) - nsites = 1 - tmax = 0.10 - time_range = 0.0:0.02:tmax - psi1_t = time_evolve(H, time_range, gs_psi; insert_kwargs, nsites, outputlevel) - (outputlevel >= 1) && println("Done with $nsites-site TDVP") - - @test norm(psi1_t) > 0.999 - - nsites = 2 - psi2_t = time_evolve(H, time_range, gs_psi; insert_kwargs, nsites, outputlevel) - (outputlevel >= 1) && println("Done with $nsites-site TDVP") - @test norm(psi2_t) > 0.999 - @test abs(inner(psi1_t, gs_psi)) > 0.99 - @test abs(inner(psi1_t, psi2_t)) > 0.99 - - # Test that accumulated phase angle is E*tmax - z = inner(psi1_t, gs_psi) - @test atan(imag(z) / real(z)) ≈ E * tmax atol = 1E-4 -end - -@testset "Applyexp Time Point Handling" begin - N = 10 - g = named_path_graph(N) - sites = siteinds("S=1/2", g) - - # Make Heisenberg model Hamiltonian - h = OpSum() - for j in 1:(N-1) - h += "Sz", j, "Sz", j + 1 - h += 1 / 2, "S+", j, "S-", j + 1 - h += 1 / 2, "S-", j, "S+", j + 1 - end - H = ttn(h, sites) - - # Initial product state - state = Dict{Int,String}() - for (j, v) in enumerate(vertices(sites)) - state[v] = iseven(j) ? "Up" : "Dn" - end - psi0 = ttn(state, sites) - - nsites = 2 - trunc = (; cutoff=1E-8, maxdim=100) - insert_kwargs = (; trunc) - - # Test that all time points are reached and reported correctly - time_points = [0.0, 0.1, 0.25, 0.32, 0.4] - times = Real[] - function collect_times(sweep_iterator; kws...) - push!(times, ITensorNetworks.current_time(ITensorNetworks.problem(sweep_iterator))) - end - time_evolve(H, time_points, psi0; insert_kwargs, nsites, sweep_callback=collect_times, outputlevel=1) - @test times ≈ time_points atol = 10 * eps(Float64) - - # Test that all exponents are reached and reported correctly - exponent_points = [-0.0, -0.1, -0.25, -0.32, -0.4] - exponents = Real[] - function collect_exponents(sweep_iterator; kws...) - push!(exponents, ITensorNetworks.current_exponent(ITensorNetworks.problem(sweep_iterator))) + @testset "Applyexp Time Point Handling" begin + N = 10 + g = named_path_graph(N) + sites = siteinds("S=1/2", g) + + # Make Heisenberg model Hamiltonian + h = OpSum() + for j in 1:(N-1) + h += "Sz", j, "Sz", j + 1 + h += 1 / 2, "S+", j, "S-", j + 1 + h += 1 / 2, "S-", j, "S+", j + 1 + end + H = ttn(h, sites) + + # Initial product state + state = Dict{Int,String}() + for (j, v) in enumerate(vertices(sites)) + state[v] = iseven(j) ? "Up" : "Dn" + end + psi0 = ttn(state, sites) + + nsites = 2 + factorize_kwargs = (; cutoff=1E-8, maxdim=100) + + # Test that all time points are reached and reported correctly + time_points = [0.0, 0.1, 0.25, 0.32, 0.4] + times = Real[] + function collect_times(sweep_iterator; kws...) + push!(times, ITensorNetworks.current_time(ITensorNetworks.problem(sweep_iterator))) + end + time_evolve(H, time_points, psi0; factorize_kwargs, nsites, sweep_callback=collect_times, outputlevel=1) + @test times ≈ time_points atol = 10 * eps(Float64) + + # Test that all exponents are reached and reported correctly + exponent_points = [-0.0, -0.1, -0.25, -0.32, -0.4] + exponents = Real[] + function collect_exponents(sweep_iterator; kws...) + push!(exponents, ITensorNetworks.current_exponent(ITensorNetworks.problem(sweep_iterator))) + end + applyexp(H, exponent_points, psi0; factorize_kwargs, nsites, sweep_callback=collect_exponents, outputlevel=1) + @test exponents ≈ exponent_points atol = 10 * eps(Float64) end - applyexp(H, exponent_points, psi0; insert_kwargs, nsites, sweep_callback=collect_exponents, outputlevel=1) - @test exponents ≈ exponent_points atol = 10 * eps(Float64) end diff --git a/test/solvers/test_eigsolve.jl b/test/solvers/test_eigsolve.jl index 1d18a6d8..5a29c2c4 100644 --- a/test/solvers/test_eigsolve.jl +++ b/test/solvers/test_eigsolve.jl @@ -38,15 +38,16 @@ include("utilities/tree_graphs.jl") cutoff = 1E-5 maxdim = 40 + + factorize_kwargs = (; cutoff, maxdim) + nsweeps = 5 # # Test 2-site DMRG without subspace expansion # nsites = 2 - trunc = (; cutoff, maxdim) - insert_kwargs = (; trunc) - E, psi = dmrg(H, psi0; insert_kwargs, nsites, nsweeps, outputlevel) + E, psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel) (outputlevel >= 1) && println("2-site DMRG energy = ", E) @test E ≈ Ex atol = 1E-5 @@ -55,10 +56,8 @@ include("utilities/tree_graphs.jl") # nsites = 1 nsweeps = 5 - trunc = (; cutoff, maxdim) - extract_kwargs = (; trunc, subspace_algorithm="densitymatrix") - insert_kwargs = (; trunc) - E, psi = dmrg(H, psi0; extract_kwargs, insert_kwargs, nsites, nsweeps, outputlevel) + extract!_kwargs = (; subspace_algorithm="densitymatrix") + E, psi = dmrg(H, psi0; extract!_kwargs, factorize_kwargs, nsites, nsweeps, outputlevel) (outputlevel >= 1) && println("1-site+subspace DMRG energy = ", E) @test E ≈ Ex atol = 1E-5 end From 0c9022c1c159c80258dbb2b5f4419f86fe1d7bcb Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 16:41:10 -0400 Subject: [PATCH 20/55] Swap order of local_state and region_iter args This is so region_iter (the mutating arg) appears first in the function sig --- src/solvers/applyexp.jl | 4 ++-- src/solvers/eigsolve.jl | 2 +- src/solvers/extract.jl | 2 +- src/solvers/fitting.jl | 2 +- src/solvers/insert.jl | 6 +++--- src/solvers/iterators.jl | 4 ++-- src/solvers/subspace/densitymatrix.jl | 4 +--- src/solvers/subspace/subspace.jl | 6 +++--- 8 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 1cba9f44..5a64d627 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -21,8 +21,8 @@ function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...) end function update!( - local_state, - region_iterator::RegionIterator{<:ApplyExpProblem}; + region_iterator::RegionIterator{<:ApplyExpProblem}, + local_state; nsites, exponent_step, solver=runge_kutta_solver, diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 2326ca04..794e6170 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -21,7 +21,7 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) end function update!( - local_state, region_iterator::RegionIterator{<:EigsolveProblem}; outputlevel, solver + region_iterator::RegionIterator{<:EigsolveProblem}, local_state; outputlevel, solver ) prob = problem(region_iterator) diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index d87d3f88..48c7a59d 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -13,7 +13,7 @@ function _extract_fallback!(region_iter::RegionIterator; subspace_algorithm) prob.state = psi - local_state = subspace_expand!(local_state, region_iter; subspace_algorithm) + local_state = subspace_expand!(region_iter, local_state; subspace_algorithm) shifted_operator = position(operator(prob), state(prob), region) prob.operator = shifted_operator diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 0ec7e536..485cfb4c 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -44,7 +44,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}) return local_tensor end -function update!(local_tensor, region_iter::RegionIterator{<:FittingProblem}; outputlevel) +function update!(region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel) F = problem(region_iter) region = current_region(region_iter) diff --git a/src/solvers/insert.jl b/src/solvers/insert.jl index 4650f524..0ce49673 100644 --- a/src/solvers/insert.jl +++ b/src/solvers/insert.jl @@ -1,12 +1,12 @@ using NamedGraphs: edgetype -function insert!(local_tensor, region_iter; kwargs...) +function insert!(region_iter, local_tensor; kwargs...) return _insert_fallback!( - local_tensor, region_iter; normalize=false, set_orthogonal_region=true, kwargs... + region_iter, local_tensor; normalize=false, set_orthogonal_region=true, kwargs... ) end -function _insert_fallback!(local_tensor, region_iter; normalize, set_orthogonal_region) +function _insert_fallback!(region_iter, local_tensor; normalize, set_orthogonal_region) prob = problem(region_iter) region = current_region(region_iter) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 97b0a62a..295a8a63 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -90,8 +90,8 @@ end function compute!(iter::RegionIterator) local_state = extract!(iter; current_kwargs(extract!, iter)...) - local_state = update!(local_state, iter; current_kwargs(update!, iter)...) - insert!(local_state, iter; current_kwargs(insert!, iter)...) + local_state = update!(iter, local_state; current_kwargs(update!, iter)...) + insert!(iter, local_state; current_kwargs(insert!, iter)...) return iter end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index c62d22aa..75e38ecf 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -4,9 +4,7 @@ using Printf: @printf function default_kwargs(::typeof(subspace_expand!), ::Backend"densitymatrix", ::Any) return (; north_pass=1) end -function subspace_expand!( - ::Backend"densitymatrix", local_state::ITensor, region_iter; north_pass -) +function subspace_expand!(::Backend"densitymatrix", region_iter, local_state; north_pass) prob = problem(region_iter) region = current_region(region_iter) diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index 529a844d..52fc5b13 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -1,7 +1,7 @@ using NDTensors: NDTensors using NDTensors.BackendSelection: Backend, @Backend_str -function subspace_expand!(local_state, region_iter; subspace_algorithm) +function subspace_expand!(region_iter, local_state; subspace_algorithm) backend = Backend(subspace_algorithm) if backend isa Backend"nothing" @@ -9,7 +9,7 @@ function subspace_expand!(local_state, region_iter; subspace_algorithm) end local_state = subspace_expand!( - backend, local_state, region_iter; current_kwargs(subspace_expand!, region_iter)... + backend, region_iter, local_state; current_kwargs(subspace_expand!, region_iter)... ) return local_state end @@ -20,7 +20,7 @@ function default_kwargs(::typeof(subspace_expand!), iter::RegionIterator) end default_kwargs(::typeof(subspace_expand!), ::Backend, ::Any) = (;) -function subspace_expand!(backend, local_state, region_iterator; kwargs...) +function subspace_expand!(backend, region_iterator, local_state; kwargs...) # We allow passing of any kwargs here is this method throws an error anyway return error( "Subspace expansion (subspace_expand!) not defined for requested combination of subspace_algorithm and problem types", From a9be11e5c4051196327a178f68ecf39ef90b1bec Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 16:41:20 -0400 Subject: [PATCH 21/55] Add some unit tests for the defaults --- test/solvers/test_defaults.jl | 43 +++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 test/solvers/test_defaults.jl diff --git a/test/solvers/test_defaults.jl b/test/solvers/test_defaults.jl new file mode 100644 index 00000000..4cd2fe6b --- /dev/null +++ b/test/solvers/test_defaults.jl @@ -0,0 +1,43 @@ +using Test: @test, @testset +using ITensorNetworks: AbstractProblem, default_kwargs, current_kwargs, RegionIterator, problem + +module KwargsTestModule + +using ITensorNetworks +using ITensorNetworks: AbstractProblem + +export TestProblem, NotOurTestProblem, test_function + +struct TestProblem <: AbstractProblem end +struct NotOurTestProblem <: AbstractProblem end + +test_function(; bool=false, int=0) = bool, int + +function ITensorNetworks.default_kwargs(::typeof(test_function), ::Type{<:AbstractProblem}) + return (; int=3) +end +function ITensorNetworks.default_kwargs(::typeof(test_function), ::Type{<:TestProblem}) + return (; bool=true) +end + +end # KwargsTestModule + +@testset "Default kwargs" begin + using .KwargsTestModule: TestProblem, NotOurTestProblem, test_function + + our_iter = RegionIterator(TestProblem(), ["region" => (; int=1)], 1) + not_our_iter = RegionIterator(NotOurTestProblem(), ["region" => (; int=2)], 1) + + # Test dispatch + @test default_kwargs(test_function, our_iter) == (; bool=true) + @test default_kwargs(test_function, problem(our_iter)) == (; bool=true) + @test default_kwargs(test_function, typeof(problem(our_iter))) == (; bool=true) + + @test default_kwargs(test_function, not_our_iter) == (; int=3) + @test default_kwargs(test_function, problem(not_our_iter)) == (; int=3) + @test default_kwargs(test_function, typeof(problem(not_our_iter))) == (; int=3) + + @test test_function(; current_kwargs(test_function, our_iter)...) == (true, 0) + @test test_function(; current_kwargs(test_function, not_our_iter)...) == (false, 3) + +end From 4d520889a5e07ffba45527e1b1d4ad4d6fe750c9 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 16:43:13 -0400 Subject: [PATCH 22/55] Rename file options.jl -> test_default_kwargs.jl --- src/ITensorNetworks.jl | 2 +- src/solvers/{options.jl => default_kwargs.jl} | 0 test/solvers/{test_defaults.jl => test_default_kwargs.jl} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename src/solvers/{options.jl => default_kwargs.jl} (100%) rename test/solvers/{test_defaults.jl => test_default_kwargs.jl} (100%) diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 7900e05a..af39aeee 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -66,7 +66,7 @@ include("solvers/abstract_problem.jl") include("solvers/eigsolve.jl") include("solvers/applyexp.jl") include("solvers/fitting.jl") -include("solvers/options.jl") +include("solvers/default_kwargs.jl") include("apply.jl") include("inner.jl") diff --git a/src/solvers/options.jl b/src/solvers/default_kwargs.jl similarity index 100% rename from src/solvers/options.jl rename to src/solvers/default_kwargs.jl diff --git a/test/solvers/test_defaults.jl b/test/solvers/test_default_kwargs.jl similarity index 100% rename from test/solvers/test_defaults.jl rename to test/solvers/test_default_kwargs.jl From 613d5336f692880a892e9163479ef6952556cde0 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 16:58:28 -0400 Subject: [PATCH 23/55] Fix `euler_sweep` returning kwargs not as `NamedTuple` --- src/solvers/region_plans/euler_plans.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/solvers/region_plans/euler_plans.jl b/src/solvers/region_plans/euler_plans.jl index 6c5304f9..68548fdc 100644 --- a/src/solvers/region_plans/euler_plans.jl +++ b/src/solvers/region_plans/euler_plans.jl @@ -2,6 +2,8 @@ using Graphs: dst, src using NamedGraphs.GraphsExtensions: default_root_vertex function euler_sweep(graph; nsites, root_vertex=default_root_vertex(graph), sweep_kwargs...) + sweep_kwargs = (; nsites, root_vertex, sweep_kwargs...) + if nsites == 1 vertices = euler_tour_vertices(graph, root_vertex) sweep = [[v] => sweep_kwargs for v in vertices] From 20bf7830b15863d6c9a5fef42c45d53ba04bb3db Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 17:05:52 -0400 Subject: [PATCH 24/55] The `sweep_solve` callbacks now get called without any keyword arguments. --- src/solvers/applyexp.jl | 9 ++++----- src/solvers/eigsolve.jl | 4 ++-- src/solvers/sweep_solve.jl | 5 ++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 5a64d627..60b356d9 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -57,9 +57,8 @@ end function default_sweep_callback( sweep_iterator::SweepIterator{<:ApplyExpProblem}; exponent_description="exponent", - outputlevel, + outputlevel=0, process_time=identity, - kwargs..., ) if outputlevel >= 1 the_problem = problem(sweep_iterator) @@ -89,7 +88,7 @@ function applyexp( ] sweep_iter = SweepIterator(init_prob, kws_array) - converged_prob = sweep_solve(sweep_callback, sweep_iter; outputlevel=0) + converged_prob = sweep_solve(sweep_callback, sweep_iter) return state(converged_prob) end @@ -110,8 +109,8 @@ function time_evolve( time_points, init_state; process_time=process_real_times, - sweep_callback=(a...; k...) -> - default_sweep_callback(a...; exponent_description="time", process_time, k...), + sweep_callback=iter -> + default_sweep_callback(iter; exponent_description="time", process_time), sweep_kwargs..., ) exponents = [-im * t for t in time_points] diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 794e6170..aef7ad8c 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -46,7 +46,7 @@ function default_kwargs(::typeof(update!), ::Type{<:EigsolveProblem}) end function default_sweep_callback( - sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel + sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0 ) if outputlevel >= 1 nsweeps = length(sweep_iterator) @@ -69,7 +69,7 @@ function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_ state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...) - prob = sweep_solve(sweep_iter; outputlevel) + prob = sweep_solve(sweep_iter) return eigenvalue(prob), state(prob) end diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index 9273bad9..7aacea7d 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -11,14 +11,13 @@ function sweep_solve( sweep_iterator; sweep_callback=default_sweep_callback, region_callback=default_region_callback, - outputlevel=0, ) # Don't compute the region iteration automatically as we wish to insert a callback. for _ in NoComputeStep(sweep_iterator) for _ in region_iterator(sweep_iterator) - region_callback(sweep_iterator; outputlevel=outputlevel) + region_callback(sweep_iterator) end - sweep_callback(sweep_iterator; outputlevel=outputlevel) + sweep_callback(sweep_iterator) end return problem(sweep_iterator) end From 568c631e9604b0d34990f593dca33ad45593c86f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 7 Oct 2025 17:31:59 -0400 Subject: [PATCH 25/55] Some minor refactoring of the iterators. - Reordered the struct fields to be consistant with each other - Some field and function renames --- src/solvers/iterators.jl | 49 +++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 295a8a63..49cab7ef 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -41,13 +41,22 @@ end mutable struct RegionIterator{Problem,RegionPlan} <: AbstractNetworkIterator problem::Problem region_plan::RegionPlan - const sweep::Int which_region::Int + const which_sweep::Int function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P,R} - return new{P,R}(problem, region_plan, sweep, 1) + return new{P,R}(problem, region_plan, 1, sweep) end end +function RegionIterator(problem; sweep, sweep_kwargs...) + plan = region_plan(problem; sweep_kwargs...) + return RegionIterator(problem, plan, sweep) +end + +function new_region_iterator(iterator::RegionIterator; sweep_kwargs...) + return RegionIterator(iterator.problem; sweep_kwargs...) +end + state(region_iter::RegionIterator) = region_iter.which_region Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) @@ -74,11 +83,10 @@ function prev_region(region_iter::RegionIterator) end function next_region(region_iter::RegionIterator) - is_last_region(region_iter) && return nothing + laststep(region_iter) && return nothing next, _ = region_iter.region_plan[region_iter.which_region + 1] return next end -is_last_region(region_iter::RegionIterator) = length(region_iter) === state(region_iter) # # Functions associated with RegionIterator @@ -96,45 +104,44 @@ function compute!(iter::RegionIterator) return iter end -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep, sweep_kwargs...) - return RegionIterator(problem, plan, sweep) -end - region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) # # SweepIterator # -mutable struct SweepIterator{Problem} <: AbstractNetworkIterator - sweep_kws +mutable struct SweepIterator{Problem,Iter} <: AbstractNetworkIterator region_iter::RegionIterator{Problem} + sweep_kwargs::Iterators.Stateful{Iter} which_sweep::Int - function SweepIterator(problem, sweep_kws) - sweep_kws = Iterators.Stateful(sweep_kws) - first_kwargs, _ = Iterators.peel(sweep_kws) + function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob,Iter} + stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) + first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs) region_iter = RegionIterator(problem; sweep=1, first_kwargs...) - return new{typeof(problem)}(sweep_kws, region_iter, 1) + return new{Prob,Iter}(region_iter, stateful_sweep_kwargs, 1) end end -laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kws)) +laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kws) +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) function increment!(sweep_iter::SweepIterator) sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kws) - sweep_iter.region_iter = RegionIterator( - problem(sweep_iter); sweep=state(sweep_iter), sweep_kwargs... - ) + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) + update_region_iterator!(sweep_iter; sweep_kwargs...) return sweep_iter end +function update_region_iterator!(iterator::SweepIterator; kwargs...) + sweep = state(iterator) + iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...) + return iterator +end + function compute!(sweep_iter::SweepIterator) for _ in sweep_iter.region_iter # TODO: Is it sensible to execute the default region callback function? From fed9137a59888f5491c933315d8d5dc096e9def5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 9 Oct 2025 10:54:35 -0400 Subject: [PATCH 26/55] The `EachRegion` adapter now flattens the nested Sweep/Region iterators into a single iterators over the regions. The length of the iterator is then nsweeps * nregions. --- src/solvers/adapters.jl | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index c1d139d0..5bed8862 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -17,23 +17,32 @@ compute!(adapter::NoComputeStep) = adapter NoComputeStep(adapter::NoComputeStep) = adapter """ - struct EachRegion{RegionIterator} <: AbstractNetworkIterator + struct EachRegion{SweepIterator} <: AbstractNetworkIterator -Wapper adapter that returns a tuple (region, kwargs) at each step rather than the iterator -itself. +Adapter that flattens the each region iterator in the parent sweep iterator into a single +iterator, returning `region => kwargs`. """ -struct EachRegion{R<:RegionIterator} <: AbstractNetworkIterator - parent::R +struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator + parent::SI end -# Essential definitions -Base.length(adapter::EachRegion) = length(adapter.parent) -state(adapter::EachRegion) = state(adapter.parent) -increment!(adapter::EachRegion) = state(adapter.parent) +# In keeping with Julia convention. +eachregion(iter::SweepIterator) = EachRegion(iter) +# Essential definitions +function laststep(adapter::EachRegion) + region_iter = region_iterator(adapter.parent) + return laststep(adapter.parent) && laststep(region_iter) +end +function increment!(adapter::EachRegion) + region_iter = region_iterator(adapter.parent) + laststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) + return adapter +end function compute!(adapter::EachRegion) - # Do the usual compute! for RegionIterator - compute!(adapter.parent) - # But now lets return something useful - return current_region_plan(adapter) + region_iter = region_iterator(adapter.parent) + compute!(region_iter) + return current_region_plan(region_iter) +end + end From 4ce453e3ae827f1c8a8b895e2ca028e090aeec90 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 9 Oct 2025 10:54:56 -0400 Subject: [PATCH 27/55] Add tests for `EachRegion` and `eachregion` wrapper functions --- test/solvers/test_iterators.jl | 80 +++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl index a39722c1..5fc77fb8 100644 --- a/test/solvers/test_iterators.jl +++ b/test/solvers/test_iterators.jl @@ -1,10 +1,21 @@ using Test: @test, @testset -using ITensorNetworks: laststep, state, increment!, compute! +using ITensorNetworks: SweepIterator, laststep, state, increment!, compute!, eachregion module TestIteratorUtils using ITensorNetworks +struct TestProblem <: ITensorNetworks.AbstractProblem + data::Vector{Int} +end +ITensorNetworks.region_plan(::TestProblem) = [:a => (; val=1), :b => (; val=2)] +function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) + kwargs = ITensorNetworks.current_region_kwargs(iter) + push!(ITensorNetworks.problem(iter).data, kwargs.val) + return iter +end + + mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator state::Int max::Int @@ -35,7 +46,7 @@ end @testset "Iterators" begin - using .TestIteratorUtils: TestIterator, SquareAdapter + using .TestIteratorUtils: TestIterator, SquareAdapter, TestProblem @testset "`AbstractNetworkIterator` Interface" begin TI = TestIterator(1, 4, []) @@ -104,23 +115,62 @@ end TI = TestIterator(1, 5, []) SA = SquareAdapter(TI) - i = 0 - for rv in SA - i += 1 - @test rv isa Int - @test rv == i^2 - @test state(SA) == i + @testset "Generic" begin + + i = 0 + for rv in SA + i += 1 + @test rv isa Int + @test rv == i^2 + @test state(SA) == i + end + + @test laststep((SA)) + + TI = TestIterator(1, 5, []) + SA = SquareAdapter(TI) + + SA_c = collect(SA) + + @test SA_c isa Vector + @test length(SA_c) == 5 + @test SA_c == [1, 4, 9, 16, 25] + end - @test laststep((SA)) + @testset "EachRegion" begin + prob = TestProblem([]) + prob_region = TestProblem([]) - TI = TestIterator(1, 5, []) - SA = SquareAdapter(TI) + SI = SweepIterator(prob, 5) + SI_region = SweepIterator(prob_region, 5) + + callback = [] + callback_region = [] + + let i = 1 + for _ in SI + push!(callback, i) + i += 1 + end + end + + @test length(callback) == 5 - SA_c = collect(SA) + let i = 1 + for _ in eachregion(SI_region) + push!(callback_region, i) + i += 1 + end + end - @test SA_c isa Vector - @test length(SA_c) == 5 - @test SA_c == [1, 4, 9, 16, 25] + @test length(callback_region) == 10 + + @test prob.data == prob_region.data + + @test prob.data[1:2:end] == fill(1, 5) + @test prob.data[2:2:end] == fill(2, 5) + + end end end From c59a9c5ba0dd92ee2c6698bc166ebd8eff07505a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 9 Oct 2025 10:56:03 -0400 Subject: [PATCH 28/55] Rename `laststep` -> `islaststep` in fitting with Julia conventions. --- src/solvers/adapters.jl | 8 ++++---- src/solvers/iterators.jl | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 5bed8862..b2072f1c 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -9,7 +9,7 @@ struct NoComputeStep{S<:AbstractNetworkIterator} <: AbstractNetworkIterator parent::S end -laststep(adapter::NoComputeStep) = laststep(adapter.parent) +islaststep(adapter::NoComputeStep) = islaststep(adapter.parent) state(adapter::NoComputeStep) = state(adapter.parent) increment!(adapter::NoComputeStep) = increment!(adapter.parent) compute!(adapter::NoComputeStep) = adapter @@ -30,13 +30,13 @@ end eachregion(iter::SweepIterator) = EachRegion(iter) # Essential definitions -function laststep(adapter::EachRegion) +function islaststep(adapter::EachRegion) region_iter = region_iterator(adapter.parent) - return laststep(adapter.parent) && laststep(region_iter) + return islaststep(adapter.parent) && islaststep(region_iter) end function increment!(adapter::EachRegion) region_iter = region_iterator(adapter.parent) - laststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) + islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) return adapter end function compute!(adapter::EachRegion) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 49cab7ef..3190a696 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -9,10 +9,10 @@ this call is implict. Termination of the iterator is controlled by the function abstract type AbstractNetworkIterator end # We use greater than or equals here as we increment the state at the start of the iteration -laststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) +islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) function Base.iterate(iterator::AbstractNetworkIterator, init=true) - laststep(iterator) && return nothing + islaststep(iterator) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. @@ -83,7 +83,7 @@ function prev_region(region_iter::RegionIterator) end function next_region(region_iter::RegionIterator) - laststep(region_iter) && return nothing + islaststep(region_iter) && return nothing next, _ = region_iter.region_plan[region_iter.which_region + 1] return next end @@ -122,7 +122,7 @@ mutable struct SweepIterator{Problem,Iter} <: AbstractNetworkIterator end end -laststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) +islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) From 62195b677051ec2fa18db7a47208b91e68dcd51c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 9 Oct 2025 16:33:55 -0400 Subject: [PATCH 29/55] Overhaul `default_kwargs` such that it mirrors the function signatures of the associated function. - This comes at the cost of some verbosity regarding getting and splatting the kwargs from the region iterator, but the usefulness of `default_kwargs` is now much wider and also more well defined - Introduce macro `@default_kwargs` for doing this automatically. --- src/ITensorNetworks.jl | 2 +- src/solvers/applyexp.jl | 19 ++-- src/solvers/default_kwargs.jl | 140 +++++++++++++++++++------- src/solvers/eigsolve.jl | 21 ++-- src/solvers/extract.jl | 8 +- src/solvers/fitting.jl | 17 ++-- src/solvers/insert.jl | 10 +- src/solvers/iterators.jl | 28 +++++- src/solvers/subspace/densitymatrix.jl | 11 +- src/solvers/subspace/subspace.jl | 30 +++--- 10 files changed, 177 insertions(+), 109 deletions(-) diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index af39aeee..ec63ee28 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -45,6 +45,7 @@ include("treetensornetworks/projttns/projttn.jl") include("treetensornetworks/projttns/projttnsum.jl") include("treetensornetworks/projttns/projouterprodttn.jl") +include("solvers/default_kwargs.jl") include("solvers/local_solvers/eigsolve.jl") include("solvers/local_solvers/exponentiate.jl") include("solvers/local_solvers/runge_kutta.jl") @@ -66,7 +67,6 @@ include("solvers/abstract_problem.jl") include("solvers/eigsolve.jl") include("solvers/applyexp.jl") include("solvers/fitting.jl") -include("solvers/default_kwargs.jl") include("apply.jl") include("inner.jl") diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 60b356d9..0a91067b 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -20,31 +20,34 @@ function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...) return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...) end -function update!( - region_iterator::RegionIterator{<:ApplyExpProblem}, +@default_kwargs function update!( + region_iter::RegionIterator{<:ApplyExpProblem}, local_state; nsites, exponent_step, solver=runge_kutta_solver, - kws..., ) - prob = problem(region_iterator) + prob = problem(region_iter) iszero(abs(exponent_step)) && return local_state + solver_kwargs = region_kwargs(solver, region_iter) + local_state, _ = solver( - x -> optimal_map(operator(prob), x), exponent_step, local_state; kws... + x -> optimal_map(operator(prob), x), exponent_step, local_state; solver_kwargs... ) if nsites == 1 - curr_reg = current_region(region_iterator) - next_reg = next_region(region_iterator) + curr_reg = current_region(region_iter) + next_reg = next_region(region_iter) if !isnothing(next_reg) && next_reg != curr_reg next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg)) v1, v2 = src(next_edge), dst(next_edge) psi = copy(state(prob)) psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2])) shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2)) - R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...) + R_t, _ = solver( + x -> optimal_map(shifted_operator, x), -exponent_step, R; solver_kwargs... + ) local_state = psi[v1] * R_t end end diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index e33723f8..de831cd6 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -1,53 +1,115 @@ +using MacroTools + """ - default_kwargs(f, [obj = Any]) + default_kwargs(f::Function, args...; kwargs...) -Return the default keyword arguments for the function `f`. These defaults may be -derived from the contents or type of the second arugment `obj`. +Returns a set of default keyword arguments, as a `NamedTuple`, for the function `f` +depending on an arbitrary number of positional arguments. Any number of these default +keyword arguments can optionally be overwritten by passing the the keyword as a +keyword argument to this function. +""" +function default_kwargs(f::Function, args...; kwargs...) + return default_kwargs(f, map(typeof, args)...; kwargs...) +end +default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...) -## Interface +""" + @default_kwargs -Given a function `f`, one can optionally set the default keyword arguments for this -function by specializing either of the following two-argument methods: +Automatically define a `default_kwargs` method for a given function. This macro should +be applied before a function definition: ``` -ITensorNetworks.default_kwargs(::typeof(f), prob::AbstractProblem) -ITensorNetworks.default_kwargs(::typeof(f), ::Type{<:AbstractProblem}) +@default_kwargs astypes = true function f(args...; kwargs...) + ... +end +``` +If `astypes = true` then the `default_kwargs` method is defined in the +type domain with respect to `args`, i.e. +``` +default_kwargs(::typeof(f), arg::T; kwargs...) # astypes = false +default_kwargs(::typeof(f), arg::Type{<:T}; kwargs...) # astypes = true ``` -If one does not require the contents of `prob::Prob` to generate the defaults then it is -recommended to dispatch on `Type{<:Prob}` directly (second method) so the defaults -can be accessed without constructing an instance of a `Prob`. - -The return value of `default_kwargs` should be a `NamedTuple`, and will overwrite any -default values set in the function signature. """ -default_kwargs(f) = default_kwargs(f, Any) -default_kwargs(f, obj) = _default_kwargs_fallback(f, obj) - -# To avoid annoying potential method ambiguities. -function _default_kwargs_fallback(f, iter::RegionIterator) - return default_kwargs(f, problem(iter)) -end -function _default_kwargs_fallback(f, problem::AbstractProblem) - return default_kwargs(f, typeof(problem)) +macro default_kwargs(args...) + kwargs = (;) + for opt in args + if @capture(opt, key_ = val_) + @info "" key val + kwargs = merge(kwargs, NamedTuple{(key,)}((val,))) + elseif opt === last(args) + return default_kwargs_macro(opt; kwargs...) + else + throw(ArgumentError("Unknown expression object")) + end + end end -# Eventually we reach this if nothing is specialized. -_default_kwargs_fallback(::Any, ::DataType) = (;) +function default_kwargs_macro(function_def; astypes=true) + if !isdef(function_def) + throw( + ArgumentError("The @default_kwargs macro must be followed by a function definition") + ) + end -""" - current_kwargs(f, iter::RegionIterator) + ex = splitdef(function_def) + new_ex = deepcopy(ex) -Return the keyword arguments to be passed to the function `f` for the current region -defined by the stateful iterator `iter`. -""" -function current_kwargs(f::Function, iter::RegionIterator) - region_kwargs = get(current_region_kwargs(iter), Symbol(f, :_kwargs), (;)) - rv = merge(default_kwargs(f, iter), region_kwargs) - return rv -end + prev_kwargs = [] + + # Give very positional argument a name and escape the type. + ex[:args] = map(ex[:args]) do arg + @capture(arg, (name_::T_) | (::T_) | name_) + if isnothing(name) + name = gensym() + end + if isnothing(T) + T = :Any + end + return :($(name)::$(esc(T))) + end + + # Replacing the kwargs values with the output of `default_kwargs` + ex[:kwargs] = map(ex[:kwargs]) do kw + @capture(kw, (key_::T_ = val_) | (key_ = val_) | key_) + if !isnothing(val) + kw.args[2] = + :(default_kwargs($(esc(ex[:name])), $(ex[:args]...); $(prev_kwargs...)).$key) + end + push!(prev_kwargs, key) + return kw + end + + # Promote to the type domain if wanted + if astypes + new_ex[:args] = map(ex[:args]) do arg + @capture(arg, name_::T_) + return :($(name)::Type{<:$T}) + end + end + + new_ex[:name] = :(ITensorNetworks.default_kwargs) + new_ex[:args] = convert(Vector{Any}, ex[:args]) -# Generic + new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name]))))) -# I think these should be set independent of a function, but for now: -function default_kwargs(::typeof(factorize), ::Any) - return (; maxdim=typemax(Int), cutoff=0.0, mindim=1) + # Escape anything on the right-hand side of a keyword definition. + new_ex[:kwargs] = map(new_ex[:kwargs]) do kw + @capture(kw, (key_ = val_) | key_) + if !isnothing(val) + kw.args[2] = esc(val) + end + return kw + end + + new_ex[:body] = :(return (; $(prev_kwargs...))) + + # Escape the actual function name + ex[:name] = :($(esc(ex[:name]))) + + rv = quote + $(combinedef(ex)) + $(combinedef(new_ex)) + end + + return rv end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index aef7ad8c..2f68ca9c 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -20,31 +20,26 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) return E end -function update!( - region_iterator::RegionIterator{<:EigsolveProblem}, local_state; outputlevel, solver +@default_kwargs function update!( + region_iter::RegionIterator{<:EigsolveProblem}, + local_state; + outputlevel=0, + solver=eigsolve_solver, ) - prob = problem(region_iterator) + prob = problem(region_iter) eigval, local_state = solver( - ψ -> optimal_map(operator(prob), ψ), - local_state; - current_kwargs(solver, region_iterator)..., + ψ -> optimal_map(operator(prob), ψ), local_state; region_kwargs(solver, region_iter)... ) prob.eigenvalue = eigval if outputlevel >= 2 - @printf( - " Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob) - ) + @printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob)) end return local_state end -function default_kwargs(::typeof(update!), ::Type{<:EigsolveProblem}) - return (; outputlevel=0, solver=eigsolve_solver) -end - function default_sweep_callback( sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0 ) diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index 48c7a59d..526f58da 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -1,10 +1,4 @@ -function extract!(iter; kwargs...) - return _extract_fallback!(iter; subspace_algorithm="nothing", kwargs...) -end - -# Internal function such that a method error can be thrown while still allowing a user -# to specialize on `extract!` -function _extract_fallback!(region_iter::RegionIterator; subspace_algorithm) +function extract!(region_iter::RegionIterator; subspace_algorithm="nothing") prob = problem(region_iter) region = current_region(region_iter) diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 485cfb4c..668dd0e4 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -44,7 +44,9 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}) return local_tensor end -function update!(region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel) +@default_kwargs function update!( + region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel=0 +) F = problem(region_iter) region = current_region(region_iter) @@ -70,8 +72,7 @@ function fit_tensornetwork( nsites=1, outputlevel=0, normalize=true, - maxdim=default_kwargs(factorize).maxdim, - cutoff=default_kwargs(factorize).cutoff, + factorize_kwargs, extra_sweep_kwargs..., ) bpc = BeliefPropagationCache(overlap_network, args...) @@ -84,7 +85,6 @@ function fit_tensornetwork( insert!_kwargs = (; normalize, set_orthogonal_region=false) update!_kwargs = (; outputlevel) - factorize_kwargs = (; maxdim, cutoff) sweep_kwargs = (; nsites, outputlevel, update!_kwargs, insert!_kwargs, factorize_kwargs) kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps] @@ -109,12 +109,11 @@ end #end function ITensors.apply( - A::ITensorNetwork, - x::ITensorNetwork; - maxdim=default_kwargs(factorize).maxdim, - sweep_kwargs..., + A::ITensorNetwork, x::ITensorNetwork; maxdim=typemax(Int), cutoff=0.0, sweep_kwargs... ) init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim) overlap_network = inner_network(x, A, init_state) - return fit_tensornetwork(overlap_network; maxdim, sweep_kwargs...) + return fit_tensornetwork( + overlap_network; factorize_kwargs=(; maxdim, cutoff), sweep_kwargs... + ) end diff --git a/src/solvers/insert.jl b/src/solvers/insert.jl index 0ce49673..7ab4bdb2 100644 --- a/src/solvers/insert.jl +++ b/src/solvers/insert.jl @@ -1,12 +1,6 @@ using NamedGraphs: edgetype -function insert!(region_iter, local_tensor; kwargs...) - return _insert_fallback!( - region_iter, local_tensor; normalize=false, set_orthogonal_region=true, kwargs... - ) -end - -function _insert_fallback!(region_iter, local_tensor; normalize, set_orthogonal_region) +function insert!(region_iter, local_tensor; normalize=false, set_orthogonal_region=true) prob = problem(region_iter) region = current_region(region_iter) @@ -19,7 +13,7 @@ function _insert_fallback!(region_iter, local_tensor; normalize, set_orthogonal_ tags = ITensors.tags(psi, e) U, C, spectrum = factorize( - local_tensor, indsTe; tags, current_kwargs(factorize, region_iter)... + local_tensor, indsTe; tags, region_kwargs(factorize, region_iter)... ) @preserve_graph psi[first(region)] = U diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 3190a696..69744d2a 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -71,10 +71,13 @@ function current_region(region_iter::RegionIterator) return region end -function current_region_kwargs(region_iter::RegionIterator) +function region_kwargs(region_iter::RegionIterator) _, kwargs = current_region_plan(region_iter) return kwargs end +function region_kwargs(f::Function, iter::RegionIterator) + return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) +end function prev_region(region_iter::RegionIterator) state(region_iter) <= 1 && return nothing @@ -96,10 +99,27 @@ function increment!(region_iter::RegionIterator) return region_iter end +# Purely for our convenience: +function extract!_kwargs(iter) + f = extract! + kwargs = region_kwargs(f, iter) + return default_kwargs(f, iter; kwargs...) +end +function update!_kwargs(iter, local_state) + f = update! + kwargs = region_kwargs(f, iter) + return default_kwargs(f, iter, local_state; kwargs...) +end +function insert!_kwargs(iter, local_state) + f = insert! + kwargs = region_kwargs(f, iter) + return default_kwargs(f, iter, local_state; kwargs...) +end + function compute!(iter::RegionIterator) - local_state = extract!(iter; current_kwargs(extract!, iter)...) - local_state = update!(iter, local_state; current_kwargs(update!, iter)...) - insert!(iter, local_state; current_kwargs(insert!, iter)...) + local_state = extract!(iter; extract!_kwargs(iter)...) + local_state = update!(iter, local_state; update!_kwargs(iter, local_state)...) + insert!(iter, local_state; insert!_kwargs(iter, local_state)...) return iter end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 75e38ecf..c1f522b6 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -1,10 +1,9 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf -function default_kwargs(::typeof(subspace_expand!), ::Backend"densitymatrix", ::Any) - return (; north_pass=1) -end -function subspace_expand!(::Backend"densitymatrix", region_iter, local_state; north_pass) +@default_kwargs function subspace_expand!( + ::Backend"densitymatrix", region_iter, local_state; north_pass=1 +) prob = problem(region_iter) region = current_region(region_iter) @@ -25,7 +24,7 @@ function subspace_expand!(::Backend"densitymatrix", region_iter, local_state; no basis_size = prod(dim.(uniqueinds(A, C))) expanded_maxdim = compute_expansion( - dim(a), basis_size; current_kwargs(compute_expansion, region_iter)... + dim(a), basis_size; region_kwargs(compute_expansion, region_iter)... ) expanded_maxdim <= 0 && return local_state @@ -43,7 +42,7 @@ function subspace_expand!(::Backend"densitymatrix", region_iter, local_state; no sqrt_rho = conj_proj_A(sqrt_rho) end rho = sqrt_rho * dag(noprime(sqrt_rho)) - D, U = eigen(rho; current_kwargs(eigen, region_iter)..., ishermitian=true) + D, U = eigen(rho; region_kwargs(eigen, region_iter)..., ishermitian=true) Uproj(T) = (T - prime(A, a) * (dag(prime(A, a)) * T)) for pass in 1:north_pass diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index 52fc5b13..0b526604 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -1,24 +1,29 @@ using NDTensors: NDTensors using NDTensors.BackendSelection: Backend, @Backend_str -function subspace_expand!(region_iter, local_state; subspace_algorithm) +@default_kwargs function subspace_expand!( + region_iter, local_state; subspace_algorithm="nothing" +) backend = Backend(subspace_algorithm) if backend isa Backend"nothing" return local_state end + subspace_expand!_kwargs = default_kwargs( + subspace_expand!, + backend, + region_iter, + local_state; + region_kwargs(subspace_expand!, region_iter)..., + ) + local_state = subspace_expand!( - backend, region_iter, local_state; current_kwargs(subspace_expand!, region_iter)... + backend, region_iter, local_state; subspace_expand!_kwargs... ) - return local_state -end -function default_kwargs(::typeof(subspace_expand!), iter::RegionIterator) - backend = current_kwargs(extract!, iter).subspace_algorithm - return default_kwargs(subspace_expand!, Backend(backend), problem(iter)) + return local_state end -default_kwargs(::typeof(subspace_expand!), ::Backend, ::Any) = (;) function subspace_expand!(backend, region_iterator, local_state; kwargs...) # We allow passing of any kwargs here is this method throws an error anyway @@ -27,7 +32,9 @@ function subspace_expand!(backend, region_iterator, local_state; kwargs...) ) end -function compute_expansion(current_dim, basis_size; expansion_factor, maxexpand, maxdim) +function compute_expansion( + current_dim, basis_size; expansion_factor=1.5, maxexpand=typemax(Int), maxdim=typemax(Int) +) # Note: expand_maxdim will be *added* to current bond dimension # Obtain expand_maxdim from expansion_factor expand_maxdim = ceil(Int, expansion_factor * current_dim) @@ -43,8 +50,3 @@ function compute_expansion(current_dim, basis_size; expansion_factor, maxexpand, return expand_maxdim end -function default_kwargs(::typeof(compute_expansion), iter::RegionIterator) - # Derived default - maxdim = current_kwargs(factorize, iter).maxdim - return (; maxexpand=typemax(Int), expansion_factor=1.5, maxdim) -end From 917f2f1ae985aa04b83aeee1c20454e29cb17424 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 9 Oct 2025 16:38:30 -0400 Subject: [PATCH 30/55] Rename `NoComputeStep` to `IncrementOnly` I think this is clearer than `JustIncrement` which might not be clear to non-native English speakers (maybe), and avoids the case of `OnlyIncrement` being confused with "the only increment". --- src/solvers/adapters.jl | 12 ++++++------ src/solvers/sweep_solve.jl | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index b2072f1c..51e8682d 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -5,16 +5,16 @@ Iterator wrapper whos `compute!` function simply returns itself, doing nothing i process. This allows one to manually call a custom `compute!` or insert their own code it in the loop body in place of `compute!`. """ -struct NoComputeStep{S<:AbstractNetworkIterator} <: AbstractNetworkIterator +struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator parent::S end -islaststep(adapter::NoComputeStep) = islaststep(adapter.parent) -state(adapter::NoComputeStep) = state(adapter.parent) -increment!(adapter::NoComputeStep) = increment!(adapter.parent) -compute!(adapter::NoComputeStep) = adapter +islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) +state(adapter::IncrementOnly) = state(adapter.parent) +increment!(adapter::IncrementOnly) = increment!(adapter.parent) +compute!(adapter::IncrementOnly) = adapter -NoComputeStep(adapter::NoComputeStep) = adapter +IncrementOnly(adapter::IncrementOnly) = adapter """ struct EachRegion{SweepIterator} <: AbstractNetworkIterator diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index 7aacea7d..53343a65 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -13,7 +13,7 @@ function sweep_solve( region_callback=default_region_callback, ) # Don't compute the region iteration automatically as we wish to insert a callback. - for _ in NoComputeStep(sweep_iterator) + for _ in IncrementOnly(sweep_iterator) for _ in region_iterator(sweep_iterator) region_callback(sweep_iterator) end From 112d55ed4e3ef7285c09e90d4fcc321ee5c6ac47 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 10 Oct 2025 10:48:09 -0400 Subject: [PATCH 31/55] Remove @info statement and fix bug with `astypes` not promoting correctly. --- src/solvers/default_kwargs.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index de831cd6..5a5c1fe4 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -34,7 +34,6 @@ macro default_kwargs(args...) kwargs = (;) for opt in args if @capture(opt, key_ = val_) - @info "" key val kwargs = merge(kwargs, NamedTuple{(key,)}((val,))) elseif opt === last(args) return default_kwargs_macro(opt; kwargs...) @@ -80,16 +79,15 @@ function default_kwargs_macro(function_def; astypes=true) end # Promote to the type domain if wanted + new_ex[:args] = convert(Vector{Any}, ex[:args]) if astypes - new_ex[:args] = map(ex[:args]) do arg + new_ex[:args] = map(new_ex[:args]) do arg @capture(arg, name_::T_) return :($(name)::Type{<:$T}) end end new_ex[:name] = :(ITensorNetworks.default_kwargs) - new_ex[:args] = convert(Vector{Any}, ex[:args]) - new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name]))))) # Escape anything on the right-hand side of a keyword definition. From 0a9f127a0262ae9eb7ddccf09f8d88640a8dffd7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 10 Oct 2025 10:48:20 -0400 Subject: [PATCH 32/55] Update `default_kwargs` tests. --- test/solvers/test_default_kwargs.jl | 39 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl index 4cd2fe6b..6368176a 100644 --- a/test/solvers/test_default_kwargs.jl +++ b/test/solvers/test_default_kwargs.jl @@ -1,23 +1,21 @@ using Test: @test, @testset -using ITensorNetworks: AbstractProblem, default_kwargs, current_kwargs, RegionIterator, problem +using ITensorNetworks: AbstractProblem, default_kwargs, RegionIterator, problem, region_kwargs module KwargsTestModule using ITensorNetworks -using ITensorNetworks: AbstractProblem +using ITensorNetworks: AbstractProblem, @default_kwargs export TestProblem, NotOurTestProblem, test_function struct TestProblem <: AbstractProblem end struct NotOurTestProblem <: AbstractProblem end -test_function(; bool=false, int=0) = bool, int - -function ITensorNetworks.default_kwargs(::typeof(test_function), ::Type{<:AbstractProblem}) - return (; int=3) +@default_kwargs astypes = true function test_function(::AbstractProblem; bool=false, int=3) + return bool, int end -function ITensorNetworks.default_kwargs(::typeof(test_function), ::Type{<:TestProblem}) - return (; bool=true) +@default_kwargs astypes = true function test_function(::TestProblem; bool=true, int=0) + return bool, int end end # KwargsTestModule @@ -25,19 +23,24 @@ end # KwargsTestModule @testset "Default kwargs" begin using .KwargsTestModule: TestProblem, NotOurTestProblem, test_function - our_iter = RegionIterator(TestProblem(), ["region" => (; int=1)], 1) - not_our_iter = RegionIterator(NotOurTestProblem(), ["region" => (; int=2)], 1) + our_iter = RegionIterator(TestProblem(), ["region" => (; test_function_kwargs=(; int=1))], 1) + not_our_iter = RegionIterator(NotOurTestProblem(), ["region" => (; test_function_kwargs=(; int=2))], 1) + + kw = region_kwargs(test_function, our_iter) + @test kw == (; int=1) + kw_not = region_kwargs(test_function, not_our_iter) + @test kw_not == (; int=2) + + @info methods(default_kwargs) # Test dispatch - @test default_kwargs(test_function, our_iter) == (; bool=true) - @test default_kwargs(test_function, problem(our_iter)) == (; bool=true) - @test default_kwargs(test_function, typeof(problem(our_iter))) == (; bool=true) + @test default_kwargs(test_function, problem(our_iter)) == (; bool=true, int=0) + @test default_kwargs(test_function, problem(our_iter) |> typeof) == (; bool=true, int=0) - @test default_kwargs(test_function, not_our_iter) == (; int=3) - @test default_kwargs(test_function, problem(not_our_iter)) == (; int=3) - @test default_kwargs(test_function, typeof(problem(not_our_iter))) == (; int=3) + @test default_kwargs(test_function, problem(not_our_iter)) == (; bool=false, int=3) + @test default_kwargs(test_function, problem(not_our_iter) |> typeof) == (; bool=false, int=3) - @test test_function(; current_kwargs(test_function, our_iter)...) == (true, 0) - @test test_function(; current_kwargs(test_function, not_our_iter)...) == (false, 3) + @test test_function(problem(our_iter); default_kwargs(test_function, problem(our_iter); kw...)...) == (true, 1) + @test test_function(problem(not_our_iter); default_kwargs(test_function, problem(not_our_iter); kw_not...)...) == (false, 2) end From e35f32562e995d903e92a7d65fdb14aa2637c5e7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 13 Oct 2025 23:46:47 -0400 Subject: [PATCH 33/55] Remove stray `end` from `adapters.jl`. --- src/solvers/adapters.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 51e8682d..d1885200 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -44,5 +44,3 @@ function compute!(adapter::EachRegion) compute!(region_iter) return current_region_plan(region_iter) end - -end From 6a8cdb1eacee618b70352333b4fdf38ebabc7de4 Mon Sep 17 00:00:00 2001 From: Jack Dunham <72548217+jack-dunham@users.noreply.github.com> Date: Tue, 14 Oct 2025 10:37:53 -0400 Subject: [PATCH 34/55] Fix typo in docstring of `EachRegion` adapter. Co-authored-by: Matt Fishman --- src/solvers/adapters.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index d1885200..5ad71faa 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -19,7 +19,7 @@ IncrementOnly(adapter::IncrementOnly) = adapter """ struct EachRegion{SweepIterator} <: AbstractNetworkIterator -Adapter that flattens the each region iterator in the parent sweep iterator into a single +Adapter that flattens each region iterator in the parent sweep iterator into a single iterator, returning `region => kwargs`. """ struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator From 9760de1baa1d97cf872b541738e2d540d2e4754b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 10:53:00 -0400 Subject: [PATCH 35/55] Function `reverse_regions` is now more concise. --- src/solvers/region_plans/tdvp_region_plans.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/solvers/region_plans/tdvp_region_plans.jl b/src/solvers/region_plans/tdvp_region_plans.jl index 7b24211c..8ee90086 100644 --- a/src/solvers/region_plans/tdvp_region_plans.jl +++ b/src/solvers/region_plans/tdvp_region_plans.jl @@ -30,12 +30,9 @@ function first_order_sweep(graph, sweep_kwargs; nsites) end function reverse_regions(region_plan) - region_plan = map(reverse(region_plan)) do region_kwargs - region, kwargs = region_kwargs + return map(reverse(region_plan)) do (region, kwargs) return reverse(region) => kwargs end - - return region_plan end # Generate the kwargs for each region. From 26ece7b0b34a0e63a52ff09ad4bcd7ccb29cb97a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 10:53:23 -0400 Subject: [PATCH 36/55] Use explicit imports in `default_kwargs.jl` --- src/solvers/default_kwargs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index 5a5c1fe4..5906c340 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -1,4 +1,4 @@ -using MacroTools +using MacroTools: @capture, splitdef, combinedef, isdef """ default_kwargs(f::Function, args...; kwargs...) From 340d8052d6d9c59f2c7a237ee18134d50cf2fcb8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 10:53:45 -0400 Subject: [PATCH 37/55] Fix test imports and broken tests in `test_iterators.jl`. --- test/solvers/test_default_kwargs.jl | 26 ++++++++----------- test/solvers/test_iterators.jl | 40 ++++++++++++++--------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl index 6368176a..4b074fc5 100644 --- a/test/solvers/test_default_kwargs.jl +++ b/test/solvers/test_default_kwargs.jl @@ -6,8 +6,6 @@ module KwargsTestModule using ITensorNetworks using ITensorNetworks: AbstractProblem, @default_kwargs -export TestProblem, NotOurTestProblem, test_function - struct TestProblem <: AbstractProblem end struct NotOurTestProblem <: AbstractProblem end @@ -21,26 +19,24 @@ end end # KwargsTestModule @testset "Default kwargs" begin - using .KwargsTestModule: TestProblem, NotOurTestProblem, test_function + import .KwargsTestModule - our_iter = RegionIterator(TestProblem(), ["region" => (; test_function_kwargs=(; int=1))], 1) - not_our_iter = RegionIterator(NotOurTestProblem(), ["region" => (; test_function_kwargs=(; int=2))], 1) + our_iter = RegionIterator(KwargsTestModule.TestProblem(), ["region" => (; test_function_kwargs=(; int=1))], 1) + not_our_iter = RegionIterator(KwargsTestModule.NotOurTestProblem(), ["region" => (; test_function_kwargs=(; int=2))], 1) - kw = region_kwargs(test_function, our_iter) + kw = region_kwargs(KwargsTestModule.test_function, our_iter) @test kw == (; int=1) - kw_not = region_kwargs(test_function, not_our_iter) + kw_not = region_kwargs(KwargsTestModule.test_function, not_our_iter) @test kw_not == (; int=2) - @info methods(default_kwargs) - # Test dispatch - @test default_kwargs(test_function, problem(our_iter)) == (; bool=true, int=0) - @test default_kwargs(test_function, problem(our_iter) |> typeof) == (; bool=true, int=0) + @test default_kwargs(KwargsTestModule.test_function, problem(our_iter)) == (; bool=true, int=0) + @test default_kwargs(KwargsTestModule.test_function, problem(our_iter) |> typeof) == (; bool=true, int=0) - @test default_kwargs(test_function, problem(not_our_iter)) == (; bool=false, int=3) - @test default_kwargs(test_function, problem(not_our_iter) |> typeof) == (; bool=false, int=3) + @test default_kwargs(KwargsTestModule.test_function, problem(not_our_iter)) == (; bool=false, int=3) + @test default_kwargs(KwargsTestModule.test_function, problem(not_our_iter) |> typeof) == (; bool=false, int=3) - @test test_function(problem(our_iter); default_kwargs(test_function, problem(our_iter); kw...)...) == (true, 1) - @test test_function(problem(not_our_iter); default_kwargs(test_function, problem(not_our_iter); kw_not...)...) == (false, 2) + @test KwargsTestModule.test_function(problem(our_iter); default_kwargs(KwargsTestModule.test_function, problem(our_iter); kw...)...) == (true, 1) + @test KwargsTestModule.test_function(problem(not_our_iter); default_kwargs(KwargsTestModule.test_function, problem(not_our_iter); kw_not...)...) == (false, 2) end diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl index 5fc77fb8..438f067e 100644 --- a/test/solvers/test_iterators.jl +++ b/test/solvers/test_iterators.jl @@ -1,5 +1,5 @@ using Test: @test, @testset -using ITensorNetworks: SweepIterator, laststep, state, increment!, compute!, eachregion +using ITensorNetworks: SweepIterator, islaststep, state, increment!, compute!, eachregion module TestIteratorUtils @@ -10,7 +10,7 @@ struct TestProblem <: ITensorNetworks.AbstractProblem end ITensorNetworks.region_plan(::TestProblem) = [:a => (; val=1), :b => (; val=2)] function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) - kwargs = ITensorNetworks.current_region_kwargs(iter) + kwargs = ITensorNetworks.region_kwargs(iter) push!(ITensorNetworks.problem(iter).data, kwargs.val) return iter end @@ -46,16 +46,16 @@ end @testset "Iterators" begin - using .TestIteratorUtils: TestIterator, SquareAdapter, TestProblem + import .TestIteratorUtils @testset "`AbstractNetworkIterator` Interface" begin - TI = TestIterator(1, 4, []) + TI = TestIteratorUtils.TestIterator(1, 4, []) - @test !laststep((TI)) + @test !islaststep((TI)) # First iterator should compute only rv, st = iterate(TI) - @test !laststep((TI)) + @test !islaststep((TI)) @test !st @test rv === TI @test length(TI.output) == 1 @@ -64,34 +64,34 @@ end @test !st rv, st = iterate(TI, st) - @test !laststep((TI)) + @test !islaststep((TI)) @test !st @test length(TI.output) == 2 @test state(TI) == 2 @test TI.output == [1, 2] increment!(TI) - @test !laststep((TI)) + @test !islaststep((TI)) @test state(TI) == 3 @test length(TI.output) == 2 @test TI.output == [1, 2] compute!(TI) - @test !laststep((TI)) + @test !islaststep((TI)) @test state(TI) == 3 @test length(TI.output) == 3 @test TI.output == [1, 2, 3] # Final Step iterate(TI, false) - @test laststep((TI)) + @test islaststep((TI)) @test state(TI) == 4 @test length(TI.output) == 4 @test TI.output == [1, 2, 3, 4] @test iterate(TI, false) === nothing - TI = TestIterator(1, 5, []) + TI = TestIteratorUtils.TestIterator(1, 5, []) cb = [] @@ -102,18 +102,18 @@ end @test cb == TI.output end - @test laststep((TI)) + @test islaststep((TI)) @test length(TI.output) == 5 @test length(cb) == 5 @test cb == TI.output - TI = TestIterator(1, 5, []) + TI = TestIteratorUtils.TestIterator(1, 5, []) end @testset "Adapters" begin - TI = TestIterator(1, 5, []) - SA = SquareAdapter(TI) + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) @testset "Generic" begin @@ -125,10 +125,10 @@ end @test state(SA) == i end - @test laststep((SA)) + @test islaststep((SA)) - TI = TestIterator(1, 5, []) - SA = SquareAdapter(TI) + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) SA_c = collect(SA) @@ -139,8 +139,8 @@ end end @testset "EachRegion" begin - prob = TestProblem([]) - prob_region = TestProblem([]) + prob = TestIteratorUtils.TestProblem([]) + prob_region = TestIteratorUtils.TestProblem([]) SI = SweepIterator(prob, 5) SI_region = SweepIterator(prob_region, 5) From 6a33f2908f631dd4383f54f47174097575edf418 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 12:11:22 -0400 Subject: [PATCH 38/55] Rename @default_kwargs -> @define_default_kwargs --- src/solvers/applyexp.jl | 2 +- src/solvers/default_kwargs.jl | 8 ++++---- src/solvers/eigsolve.jl | 2 +- src/solvers/fitting.jl | 2 +- src/solvers/subspace/densitymatrix.jl | 2 +- src/solvers/subspace/subspace.jl | 2 +- test/solvers/test_default_kwargs.jl | 6 +++--- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 0a91067b..e5465b40 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -20,7 +20,7 @@ function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...) return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...) end -@default_kwargs function update!( +@define_default_kwargs function update!( region_iter::RegionIterator{<:ApplyExpProblem}, local_state; nsites, diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index 5906c340..12e46f38 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -14,12 +14,12 @@ end default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...) """ - @default_kwargs + @define_default_kwargs Automatically define a `default_kwargs` method for a given function. This macro should be applied before a function definition: ``` -@default_kwargs astypes = true function f(args...; kwargs...) +@define_default_kwargs astypes = true function f(args...; kwargs...) ... end ``` @@ -30,7 +30,7 @@ default_kwargs(::typeof(f), arg::T; kwargs...) # astypes = false default_kwargs(::typeof(f), arg::Type{<:T}; kwargs...) # astypes = true ``` """ -macro default_kwargs(args...) +macro define_default_kwargs(args...) kwargs = (;) for opt in args if @capture(opt, key_ = val_) @@ -46,7 +46,7 @@ end function default_kwargs_macro(function_def; astypes=true) if !isdef(function_def) throw( - ArgumentError("The @default_kwargs macro must be followed by a function definition") + ArgumentError("The @define_default_kwargs macro must be followed by a function definition") ) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 2f68ca9c..e401581c 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -20,7 +20,7 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) return E end -@default_kwargs function update!( +@define_default_kwargs function update!( region_iter::RegionIterator{<:EigsolveProblem}, local_state; outputlevel=0, diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 668dd0e4..7ceeb1f2 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -44,7 +44,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}) return local_tensor end -@default_kwargs function update!( +@define_default_kwargs function update!( region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel=0 ) F = problem(region_iter) diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index c1f522b6..53f9d8cd 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -1,7 +1,7 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf -@default_kwargs function subspace_expand!( +@define_default_kwargs function subspace_expand!( ::Backend"densitymatrix", region_iter, local_state; north_pass=1 ) prob = problem(region_iter) diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index 0b526604..bf849604 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -1,7 +1,7 @@ using NDTensors: NDTensors using NDTensors.BackendSelection: Backend, @Backend_str -@default_kwargs function subspace_expand!( +@define_default_kwargs function subspace_expand!( region_iter, local_state; subspace_algorithm="nothing" ) backend = Backend(subspace_algorithm) diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl index 4b074fc5..675fb710 100644 --- a/test/solvers/test_default_kwargs.jl +++ b/test/solvers/test_default_kwargs.jl @@ -4,15 +4,15 @@ using ITensorNetworks: AbstractProblem, default_kwargs, RegionIterator, problem, module KwargsTestModule using ITensorNetworks -using ITensorNetworks: AbstractProblem, @default_kwargs +using ITensorNetworks: AbstractProblem, @define_default_kwargs struct TestProblem <: AbstractProblem end struct NotOurTestProblem <: AbstractProblem end -@default_kwargs astypes = true function test_function(::AbstractProblem; bool=false, int=3) +@define_default_kwargs astypes = true function test_function(::AbstractProblem; bool=false, int=3) return bool, int end -@default_kwargs astypes = true function test_function(::TestProblem; bool=true, int=0) +@define_default_kwargs astypes = true function test_function(::TestProblem; bool=true, int=0) return bool, int end From b4bcb937f54acf7da7c778b7cfedd963e8655bf0 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 13:03:28 -0400 Subject: [PATCH 39/55] Remove `astypes` option from `@define_default_kwargs`. --- src/solvers/default_kwargs.jl | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index 12e46f38..88238888 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -19,34 +19,26 @@ default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...) Automatically define a `default_kwargs` method for a given function. This macro should be applied before a function definition: ``` -@define_default_kwargs astypes = true function f(args...; kwargs...) +@define_default_kwargs function f(arg1::T1, arg2::T2, ...; kwargs...) ... end ``` -If `astypes = true` then the `default_kwargs` method is defined in the -type domain with respect to `args`, i.e. +The defined `default_kwargs` method takes the form ``` -default_kwargs(::typeof(f), arg::T; kwargs...) # astypes = false -default_kwargs(::typeof(f), arg::Type{<:T}; kwargs...) # astypes = true +default_kwargs(::typeof(f), arg1::T1, arg2::T2, ...; kwargs...) ``` +i.e. the function signature mirrors that of the function signature of `f`. """ -macro define_default_kwargs(args...) - kwargs = (;) - for opt in args - if @capture(opt, key_ = val_) - kwargs = merge(kwargs, NamedTuple{(key,)}((val,))) - elseif opt === last(args) - return default_kwargs_macro(opt; kwargs...) - else - throw(ArgumentError("Unknown expression object")) - end - end +macro define_default_kwargs(function_def) + return default_kwargs_macro(function_def) end -function default_kwargs_macro(function_def; astypes=true) +function default_kwargs_macro(function_def) if !isdef(function_def) throw( - ArgumentError("The @define_default_kwargs macro must be followed by a function definition") + ArgumentError( + "The @define_default_kwargs macro must be followed by a function definition" + ), ) end @@ -78,14 +70,7 @@ function default_kwargs_macro(function_def; astypes=true) return kw end - # Promote to the type domain if wanted new_ex[:args] = convert(Vector{Any}, ex[:args]) - if astypes - new_ex[:args] = map(new_ex[:args]) do arg - @capture(arg, name_::T_) - return :($(name)::Type{<:$T}) - end - end new_ex[:name] = :(ITensorNetworks.default_kwargs) new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name]))))) From 624f9646bdf6e947a72190f139183f32834403db Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 13:03:54 -0400 Subject: [PATCH 40/55] Update `default_kwargs` tests. --- test/solvers/test_default_kwargs.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl index 675fb710..be422ffc 100644 --- a/test/solvers/test_default_kwargs.jl +++ b/test/solvers/test_default_kwargs.jl @@ -9,10 +9,10 @@ using ITensorNetworks: AbstractProblem, @define_default_kwargs struct TestProblem <: AbstractProblem end struct NotOurTestProblem <: AbstractProblem end -@define_default_kwargs astypes = true function test_function(::AbstractProblem; bool=false, int=3) +@define_default_kwargs function test_function(::AbstractProblem; bool=false, int=3) return bool, int end -@define_default_kwargs astypes = true function test_function(::TestProblem; bool=true, int=0) +@define_default_kwargs function test_function(::TestProblem; bool=true, int=0) return bool, int end @@ -31,10 +31,8 @@ end # KwargsTestModule # Test dispatch @test default_kwargs(KwargsTestModule.test_function, problem(our_iter)) == (; bool=true, int=0) - @test default_kwargs(KwargsTestModule.test_function, problem(our_iter) |> typeof) == (; bool=true, int=0) @test default_kwargs(KwargsTestModule.test_function, problem(not_our_iter)) == (; bool=false, int=3) - @test default_kwargs(KwargsTestModule.test_function, problem(not_our_iter) |> typeof) == (; bool=false, int=3) @test KwargsTestModule.test_function(problem(our_iter); default_kwargs(KwargsTestModule.test_function, problem(our_iter); kw...)...) == (true, 1) @test KwargsTestModule.test_function(problem(not_our_iter); default_kwargs(KwargsTestModule.test_function, problem(not_our_iter); kw_not...)...) == (false, 2) From bd35f0915b3ac438ecc94528621a45a92154e90e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 13:18:02 -0400 Subject: [PATCH 41/55] Add `sweep_solve` method for `EachRegion` adapter. A callback occuring at each region can be passed using the `do` syntax using this method. --- src/solvers/sweep_solve.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index 53343a65..beda1def 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -24,6 +24,23 @@ end # I suspect that `sweep_callback` is the more commonly used callback, so allow this to # be set using the `do` syntax. -function sweep_solve(sweep_callback, sweep_iterator; kwargs...) - return sweep_solve(sweep_iterator; sweep_callback, kwargs...) +function sweep_solve( + sweep_callback, sweep_iterator; region_callback=default_region_callback +) + return sweep_solve(sweep_iterator; sweep_callback, region_callback) +end + +function sweep_solve( + each_region_iterator::EachRegion; region_callback=default_region_callback +) + return sweep_solve(region_callback, each_region_iterator) +end +function sweep_solve(region_callback, each_region_iterator::EachRegion) + for _ in each_region_iterator + # I don't think it is obvious what object this particular callback should take, + # but for now be consistant and pass the parent sweep iterator. + sweep_iterator = each_region_iterator.parent + region_callback(sweep_iterator) + end + return problem(each_region_iterator) end From 0b5314dc4e8a404963766a216a6ed13e56c1544a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 14:18:29 -0400 Subject: [PATCH 42/55] Add `@with_kwargs` macro which automatically splats `default_kwargs` into a function call. --- src/solvers/default_kwargs.jl | 17 +++++++++++++++++ test/solvers/test_default_kwargs.jl | 9 ++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl index 88238888..c58925ad 100644 --- a/src/solvers/default_kwargs.jl +++ b/src/solvers/default_kwargs.jl @@ -96,3 +96,20 @@ function default_kwargs_macro(function_def) return rv end + +macro with_defaults(call_expr) + if @capture(call_expr, (func_(args__; kwargs__)) | (func_(args__))) + if isnothing(kwargs) + kwargs = [] + end + rv = quote + $(esc(func))( + $(esc.(args)...); + default_kwargs($(esc(func)), $(esc.(args)...); $(esc.(kwargs)...))..., + ) + end + return rv + else + throw(ArgumentError("unable to parse function call expression, try including brackets in the macro call.")) + end +end diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl index be422ffc..010b0edd 100644 --- a/test/solvers/test_default_kwargs.jl +++ b/test/solvers/test_default_kwargs.jl @@ -1,5 +1,5 @@ using Test: @test, @testset -using ITensorNetworks: AbstractProblem, default_kwargs, RegionIterator, problem, region_kwargs +using ITensorNetworks: AbstractProblem, default_kwargs, RegionIterator, problem, region_kwargs, @with_defaults module KwargsTestModule @@ -37,4 +37,11 @@ end # KwargsTestModule @test KwargsTestModule.test_function(problem(our_iter); default_kwargs(KwargsTestModule.test_function, problem(our_iter); kw...)...) == (true, 1) @test KwargsTestModule.test_function(problem(not_our_iter); default_kwargs(KwargsTestModule.test_function, problem(not_our_iter); kw_not...)...) == (false, 2) + @test @with_defaults(KwargsTestModule.test_function(problem(our_iter))) == (true, 0) + @test @with_defaults(KwargsTestModule.test_function(problem(our_iter);)) == (true, 0) + @test @with_defaults(KwargsTestModule.test_function(problem(our_iter); bool = false)) == (false, 0) + + let testval = @with_defaults KwargsTestModule.test_function(problem(our_iter); int = 3) + @test testval == (true, 3) + end end From a58ec9267cc5b36540b1d59f9e2b2f51372222e2 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 14 Oct 2025 14:28:40 -0400 Subject: [PATCH 43/55] Make use of `@with_kwargs` macro make code more concise. --- src/solvers/iterators.jl | 23 +++-------------------- src/solvers/subspace/subspace.jl | 12 ++---------- 2 files changed, 5 insertions(+), 30 deletions(-) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 69744d2a..b0bcfc94 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -99,27 +99,10 @@ function increment!(region_iter::RegionIterator) return region_iter end -# Purely for our convenience: -function extract!_kwargs(iter) - f = extract! - kwargs = region_kwargs(f, iter) - return default_kwargs(f, iter; kwargs...) -end -function update!_kwargs(iter, local_state) - f = update! - kwargs = region_kwargs(f, iter) - return default_kwargs(f, iter, local_state; kwargs...) -end -function insert!_kwargs(iter, local_state) - f = insert! - kwargs = region_kwargs(f, iter) - return default_kwargs(f, iter, local_state; kwargs...) -end - function compute!(iter::RegionIterator) - local_state = extract!(iter; extract!_kwargs(iter)...) - local_state = update!(iter, local_state; update!_kwargs(iter, local_state)...) - insert!(iter, local_state; insert!_kwargs(iter, local_state)...) + local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...) + local_state = @with_defaults update!(iter, local_state; region_kwargs(update!, iter)...) + @with_defaults insert!(iter, local_state; region_kwargs(insert!, iter)...) return iter end diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index bf849604..ef2b50c0 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -10,16 +10,8 @@ using NDTensors.BackendSelection: Backend, @Backend_str return local_state end - subspace_expand!_kwargs = default_kwargs( - subspace_expand!, - backend, - region_iter, - local_state; - region_kwargs(subspace_expand!, region_iter)..., - ) - - local_state = subspace_expand!( - backend, region_iter, local_state; subspace_expand!_kwargs... + local_state = @with_defaults subspace_expand!( + backend, region_iter, local_state; region_kwargs(subspace_expand!, region_iter)... ) return local_state From b72a08f5dd738fab30749e9c881eede1c63eba80 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 09:37:16 -0400 Subject: [PATCH 44/55] The fallback default callback functions now no longer accept `kwargs...`. No kwargs get passed to these callbacks anyway so this is cosmetic change. --- src/solvers/sweep_solve.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index beda1def..c469ac63 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -1,10 +1,7 @@ -function default_region_callback(sweep_iterator; kwargs...) - return sweep_iterator -end -function default_sweep_callback(sweep_iterator; kwargs...) - return sweep_iterator -end +default_region_callback(sweep_iterator) = sweep_iterator +default_sweep_callback(sweep_iterator) = sweep_iterator + # In this implementation the function `sweep_solve` is essentially just a wrapper around # the iterate interface that allows one to pass callbacks. function sweep_solve( From c5de5c47b569b51b24149dc2623158fee4ed149b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 17:48:55 -0400 Subject: [PATCH 45/55] Test fix: tests founds in sub-directories are now actually ran when including `runtests.jl`. --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b8..fb2673d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,7 @@ end @time begin # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) + for testgroup in filter(f -> isdir(joinpath(@__DIR__, f)), readdir(@__DIR__)) if GROUP == "ALL" || GROUP == uppercase(testgroup) groupdir = joinpath(@__DIR__, testgroup) for file in filter(istestfile, readdir(groupdir)) From 2788057b9f72af4b521906e46c024d56c627be4e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 17:52:12 -0400 Subject: [PATCH 46/55] Skip broken tests for now --- test/{test_ttn_contract.jl => _test_ttn_contract.jl} | 2 ++ test/{test_ttn_dmrg.jl => _test_ttn_dmrg.jl} | 0 test/{test_ttn_dmrg_x.jl => _test_ttn_dmrg_x.jl} | 0 test/{test_ttn_linsolve.jl => _test_ttn_linsolve.jl} | 0 test/{test_ttn_tdvp.jl => _test_ttn_tdvp.jl} | 0 ..._tdvp_time_dependent.jl => _test_ttn_tdvp_time_dependent.jl} | 0 6 files changed, 2 insertions(+) rename test/{test_ttn_contract.jl => _test_ttn_contract.jl} (99%) rename test/{test_ttn_dmrg.jl => _test_ttn_dmrg.jl} (100%) rename test/{test_ttn_dmrg_x.jl => _test_ttn_dmrg_x.jl} (100%) rename test/{test_ttn_linsolve.jl => _test_ttn_linsolve.jl} (100%) rename test/{test_ttn_tdvp.jl => _test_ttn_tdvp.jl} (100%) rename test/{test_ttn_tdvp_time_dependent.jl => _test_ttn_tdvp_time_dependent.jl} (100%) diff --git a/test/test_ttn_contract.jl b/test/_test_ttn_contract.jl similarity index 99% rename from test/test_ttn_contract.jl rename to test/_test_ttn_contract.jl index 500826c3..f3cdbc48 100644 --- a/test/test_ttn_contract.jl +++ b/test/_test_ttn_contract.jl @@ -22,6 +22,8 @@ using NamedGraphs.NamedGraphGenerators: named_comb_tree using StableRNGs: StableRNG using Test: @test, @test_broken, @testset +# These tests are broken currently + @testset "Contract MPO" begin N = 20 s = siteinds("S=1/2", N) diff --git a/test/test_ttn_dmrg.jl b/test/_test_ttn_dmrg.jl similarity index 100% rename from test/test_ttn_dmrg.jl rename to test/_test_ttn_dmrg.jl diff --git a/test/test_ttn_dmrg_x.jl b/test/_test_ttn_dmrg_x.jl similarity index 100% rename from test/test_ttn_dmrg_x.jl rename to test/_test_ttn_dmrg_x.jl diff --git a/test/test_ttn_linsolve.jl b/test/_test_ttn_linsolve.jl similarity index 100% rename from test/test_ttn_linsolve.jl rename to test/_test_ttn_linsolve.jl diff --git a/test/test_ttn_tdvp.jl b/test/_test_ttn_tdvp.jl similarity index 100% rename from test/test_ttn_tdvp.jl rename to test/_test_ttn_tdvp.jl diff --git a/test/test_ttn_tdvp_time_dependent.jl b/test/_test_ttn_tdvp_time_dependent.jl similarity index 100% rename from test/test_ttn_tdvp_time_dependent.jl rename to test/_test_ttn_tdvp_time_dependent.jl From 33b9e28789b8a77863ee367b842d062b11bb50a7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 18:03:56 -0400 Subject: [PATCH 47/55] Rename `sweep_solve` -> `sweep_solve!` to obey convention --- src/solvers/applyexp.jl | 2 +- src/solvers/eigsolve.jl | 2 +- src/solvers/fitting.jl | 2 +- src/solvers/sweep_solve.jl | 12 ++++++------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index e5465b40..592ae64c 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -91,7 +91,7 @@ function applyexp( ] sweep_iter = SweepIterator(init_prob, kws_array) - converged_prob = sweep_solve(sweep_callback, sweep_iter) + converged_prob = sweep_solve!(sweep_callback, sweep_iter) return state(converged_prob) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index e401581c..df2851a4 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -64,7 +64,7 @@ function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_ state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...) - prob = sweep_solve(sweep_iter) + prob = sweep_solve!(sweep_iter) return eigenvalue(prob), state(prob) end diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 7ceeb1f2..18e8c54a 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -90,7 +90,7 @@ function fit_tensornetwork( kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps] sweep_iter = SweepIterator(init_prob, kwargs_array) - converged_prob = sweep_solve(sweep_iter) + converged_prob = sweep_solve!(sweep_iter) return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob)) end diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index c469ac63..b30e3d21 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -4,7 +4,7 @@ default_sweep_callback(sweep_iterator) = sweep_iterator # In this implementation the function `sweep_solve` is essentially just a wrapper around # the iterate interface that allows one to pass callbacks. -function sweep_solve( +function sweep_solve!( sweep_iterator; sweep_callback=default_sweep_callback, region_callback=default_region_callback, @@ -21,18 +21,18 @@ end # I suspect that `sweep_callback` is the more commonly used callback, so allow this to # be set using the `do` syntax. -function sweep_solve( +function sweep_solve!( sweep_callback, sweep_iterator; region_callback=default_region_callback ) - return sweep_solve(sweep_iterator; sweep_callback, region_callback) + return sweep_solve!(sweep_iterator; sweep_callback, region_callback) end -function sweep_solve( +function sweep_solve!( each_region_iterator::EachRegion; region_callback=default_region_callback ) - return sweep_solve(region_callback, each_region_iterator) + return sweep_solve!(region_callback, each_region_iterator) end -function sweep_solve(region_callback, each_region_iterator::EachRegion) +function sweep_solve!(region_callback, each_region_iterator::EachRegion) for _ in each_region_iterator # I don't think it is obvious what object this particular callback should take, # but for now be consistant and pass the parent sweep iterator. From dedd82ed8ff1d32115f8c26e40e957683f096ba4 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 18:06:03 -0400 Subject: [PATCH 48/55] The `EachRegion` adapter now returns itself from `iterate` instead of the region plan. This is to keep it consistant with other examples of `AbstractNetworkIterator`. --- src/solvers/adapters.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solvers/adapters.jl b/src/solvers/adapters.jl index 5ad71faa..e1c58d15 100644 --- a/src/solvers/adapters.jl +++ b/src/solvers/adapters.jl @@ -20,7 +20,7 @@ IncrementOnly(adapter::IncrementOnly) = adapter struct EachRegion{SweepIterator} <: AbstractNetworkIterator Adapter that flattens each region iterator in the parent sweep iterator into a single -iterator, returning `region => kwargs`. +iterator. """ struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator parent::SI @@ -42,5 +42,5 @@ end function compute!(adapter::EachRegion) region_iter = region_iterator(adapter.parent) compute!(region_iter) - return current_region_plan(region_iter) + return adapter end From d39f09ed5748badad7c97ad9532980e91511a15e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 18:08:02 -0400 Subject: [PATCH 49/55] The `sweep_solve!` function now always returns the type of the input iterator. --- src/solvers/sweep_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solvers/sweep_solve.jl b/src/solvers/sweep_solve.jl index b30e3d21..f67ce4de 100644 --- a/src/solvers/sweep_solve.jl +++ b/src/solvers/sweep_solve.jl @@ -16,7 +16,7 @@ function sweep_solve!( end sweep_callback(sweep_iterator) end - return problem(sweep_iterator) + return sweep_iterator end # I suspect that `sweep_callback` is the more commonly used callback, so allow this to @@ -39,5 +39,5 @@ function sweep_solve!(region_callback, each_region_iterator::EachRegion) sweep_iterator = each_region_iterator.parent region_callback(sweep_iterator) end - return problem(each_region_iterator) + return each_region_iterator end From 3f5c97c88b250e64fde45508c15f7def00ca0cad Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 15 Oct 2025 18:28:42 -0400 Subject: [PATCH 50/55] Mutating functions now return the first argument before any additional data. --- src/solvers/applyexp.jl | 8 +++++--- src/solvers/eigsolve.jl | 4 ++-- src/solvers/extract.jl | 4 ++-- src/solvers/fitting.jl | 12 ++++++++---- src/solvers/insert.jl | 2 +- src/solvers/iterators.jl | 6 ++++-- src/solvers/subspace/densitymatrix.jl | 12 ++++++------ src/solvers/subspace/subspace.jl | 6 +++--- 8 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 592ae64c..9c080cfb 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -29,7 +29,9 @@ end ) prob = problem(region_iter) - iszero(abs(exponent_step)) && return local_state + if iszero(abs(exponent_step)) + return region_iter, local_state + end solver_kwargs = region_kwargs(solver, region_iter) @@ -54,7 +56,7 @@ end prob.current_exponent += exponent_step - return local_state + return region_iter, local_state end function default_sweep_callback( @@ -91,7 +93,7 @@ function applyexp( ] sweep_iter = SweepIterator(init_prob, kws_array) - converged_prob = sweep_solve!(sweep_callback, sweep_iter) + converged_prob = problem(sweep_solve!(sweep_callback, sweep_iter)) return state(converged_prob) end diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index df2851a4..996979a4 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -37,7 +37,7 @@ end if outputlevel >= 2 @printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob)) end - return local_state + return region_iter, local_state end function default_sweep_callback( @@ -64,7 +64,7 @@ function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_ state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...) - prob = sweep_solve!(sweep_iter) + prob = problem(sweep_solve!(sweep_iter)) return eigenvalue(prob), state(prob) end diff --git a/src/solvers/extract.jl b/src/solvers/extract.jl index 526f58da..629b70f2 100644 --- a/src/solvers/extract.jl +++ b/src/solvers/extract.jl @@ -7,10 +7,10 @@ function extract!(region_iter::RegionIterator; subspace_algorithm="nothing") prob.state = psi - local_state = subspace_expand!(region_iter, local_state; subspace_algorithm) + _, local_state = subspace_expand!(region_iter, local_state; subspace_algorithm) shifted_operator = position(operator(prob), state(prob), region) prob.operator = shifted_operator - return local_state + return region_iter, local_state end diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index 18e8c54a..d0dfddb8 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -41,7 +41,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}) prob.state = tn prob.gauge_region = region - return local_tensor + return region_iter, local_tensor end @define_default_kwargs function update!( @@ -58,7 +58,7 @@ end @printf(" Region %s: squared overlap = %.12f\n", region, overlap(F)) end - return local_tensor + return region_iter, local_tensor end function region_plan(F::FittingProblem; nsites, sweep_kwargs...) @@ -90,7 +90,7 @@ function fit_tensornetwork( kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps] sweep_iter = SweepIterator(init_prob, kwargs_array) - converged_prob = sweep_solve!(sweep_iter) + converged_prob = problem(sweep_solve!(sweep_iter)) return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob)) end @@ -109,7 +109,11 @@ end #end function ITensors.apply( - A::ITensorNetwork, x::ITensorNetwork; maxdim=typemax(Int), cutoff=0.0, sweep_kwargs... + A::AbstractITensorNetwork, + x::AbstractITensorNetwork; + maxdim=typemax(Int), + cutoff=0.0, + sweep_kwargs..., ) init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim) overlap_network = inner_network(x, A, init_state) diff --git a/src/solvers/insert.jl b/src/solvers/insert.jl index 7ab4bdb2..87ffaf6d 100644 --- a/src/solvers/insert.jl +++ b/src/solvers/insert.jl @@ -28,5 +28,5 @@ function insert!(region_iter, local_tensor; normalize=false, set_orthogonal_regi prob.state = psi - return prob + return region_iter end diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index b0bcfc94..b7347bb9 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -100,8 +100,10 @@ function increment!(region_iter::RegionIterator) end function compute!(iter::RegionIterator) - local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...) - local_state = @with_defaults update!(iter, local_state; region_kwargs(update!, iter)...) + _, local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...) + _, local_state = @with_defaults update!( + iter, local_state; region_kwargs(update!, iter)... + ) @with_defaults insert!(iter, local_state; region_kwargs(insert!, iter)...) return iter diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 53f9d8cd..94e047ab 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -10,23 +10,23 @@ using Printf: @printf psi = copy(state(prob)) prev_vertex_set = setdiff(pos(operator(prob)), region) - (length(prev_vertex_set) != 1) && return local_state + (length(prev_vertex_set) != 1) && return region_iter, local_state prev_vertex = only(prev_vertex_set) A = psi[prev_vertex] next_vertices = filter(v -> (hascommoninds(psi[v], A)), region) - isempty(next_vertices) && return local_state + isempty(next_vertices) && return region_iter, local_state next_vertex = only(next_vertices) C = psi[next_vertex] a = commonind(A, C) - isnothing(a) && return local_state + isnothing(a) && return region_iter, local_state basis_size = prod(dim.(uniqueinds(A, C))) expanded_maxdim = compute_expansion( dim(a), basis_size; region_kwargs(compute_expansion, region_iter)... ) - expanded_maxdim <= 0 && return local_state + expanded_maxdim <= 0 && return region_iter, local_state envs = environments(operator(prob)) H = operator(operator(prob)) @@ -50,7 +50,7 @@ using Printf: @printf end if norm(dag(U) * A) > 1E-10 @printf("Warning: |U*A| = %.3E in subspace expansion\n", norm(dag(U) * A)) - return local_state + return region_iter, local_state end Ax, ax = directsum(A => a, U => commonind(U, D)) @@ -61,5 +61,5 @@ using Printf: @printf prob.state = psi - return local_state + return region_iter, local_state end diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index ef2b50c0..c6177eb3 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -7,14 +7,14 @@ using NDTensors.BackendSelection: Backend, @Backend_str backend = Backend(subspace_algorithm) if backend isa Backend"nothing" - return local_state + return region_iter, local_state end - local_state = @with_defaults subspace_expand!( + _, local_state = @with_defaults subspace_expand!( backend, region_iter, local_state; region_kwargs(subspace_expand!, region_iter)... ) - return local_state + return region_iter, local_state end function subspace_expand!(backend, region_iterator, local_state; kwargs...) From 7ad31387e7ac47d970813fd77bbe9e1d669a8e90 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 16 Oct 2025 15:47:31 -0400 Subject: [PATCH 51/55] Remove depreciated `solvers` code and tests from old interface Revert "Skip broken tests correctly" This reverts commit c675c6f64c9a13b46dcb36d292af12482ffff5c6. Delete depreciated code To be reintroduced in ITensorNetworksNext Squash of ^ This removes `solvers/previous_interfaces` and the broken test files --- src/solvers/previous_interfaces/contract.jl | 86 --- src/solvers/previous_interfaces/dmrg.jl | 23 - src/solvers/previous_interfaces/dmrg_x.jl | 19 - src/solvers/previous_interfaces/linsolve.jl | 49 -- src/solvers/previous_interfaces/tdvp.jl | 154 ----- test/_test_ttn_contract.jl | 154 ----- test/_test_ttn_dmrg.jl | 328 ---------- test/_test_ttn_dmrg_x.jl | 70 --- test/_test_ttn_linsolve.jl | 48 -- test/_test_ttn_tdvp.jl | 663 -------------------- test/_test_ttn_tdvp_time_dependent.jl | 236 ------- 11 files changed, 1830 deletions(-) delete mode 100644 src/solvers/previous_interfaces/contract.jl delete mode 100644 src/solvers/previous_interfaces/dmrg.jl delete mode 100644 src/solvers/previous_interfaces/dmrg_x.jl delete mode 100644 src/solvers/previous_interfaces/linsolve.jl delete mode 100644 src/solvers/previous_interfaces/tdvp.jl delete mode 100644 test/_test_ttn_contract.jl delete mode 100644 test/_test_ttn_dmrg.jl delete mode 100644 test/_test_ttn_dmrg_x.jl delete mode 100644 test/_test_ttn_linsolve.jl delete mode 100644 test/_test_ttn_tdvp.jl delete mode 100644 test/_test_ttn_tdvp_time_dependent.jl diff --git a/src/solvers/previous_interfaces/contract.jl b/src/solvers/previous_interfaces/contract.jl deleted file mode 100644 index 00e5c4d6..00000000 --- a/src/solvers/previous_interfaces/contract.jl +++ /dev/null @@ -1,86 +0,0 @@ -using Graphs: nv, vertices -using ITensors: ITensors, sim -using ITensors.NDTensors: Algorithm, @Algorithm_str, contract -using NamedGraphs: vertextype - -function sum_contract( - ::Algorithm"fit", - tns::Vector{<:Tuple{<:AbstractTTN,<:AbstractTTN}}; - init, - nsites=2, - nsweeps=1, - cutoff=eps(), - updater=contract_updater, - kwargs..., -) - tn1s = first.(tns) - tn2s = last.(tns) - ns = nv.(tn1s) - n = first(ns) - any(ns .!= nv.(tn2s)) && throw( - DimensionMismatch("Number of sites operator ($n) and state ($(nv(tn2))) do not match") - ) - any(ns .!= n) && - throw(DimensionMismatch("Number of sites in different operators ($n) do not match")) - # ToDo: Write test for single-vertex ttn, this implementation has not been tested. - if n == 1 - res = 0 - for (tn1, tn2) in zip(tn1s, tn2s) - v = only(vertices(tn2)) - res += tn1[v] * tn2[v] - end - return typeof(tn2)([res]) - end - - # In case `tn1` and `tn2` have the same internal indices - operator = ProjOuterProdTTN{vertextype(first(tn1s))}[] - for (tn1, tn2) in zip(tn1s, tn2s) - tn1 = sim(linkinds, tn1) - - # In case `init` and `tn2` have the same internal indices - init = sim(linkinds, init) - push!(operator, ProjOuterProdTTN(tn2, tn1)) - end - operator = isone(length(operator)) ? only(operator) : ProjTTNSum(operator) - #ToDo: remove? - # Fix site and link inds of init - ## init = deepcopy(init) - ## init = sim(linkinds, init) - ## for v in vertices(tn2) - ## replaceinds!( - ## init[v], siteinds(init, v), uniqueinds(siteinds(tn1, v), siteinds(tn2, v)) - ## ) - ## end - - return alternating_update(operator, init; nsweeps, nsites, updater, cutoff, kwargs...) -end - -function NDTensors.contract( - a::Algorithm"fit", tn1::AbstractTTN, tn2::AbstractTTN; kwargs... -) - return sum_contract(a, [(tn1, tn2)]; kwargs...) -end - -""" -Overload of `ITensors.contract`. -""" -function NDTensors.contract(tn1::AbstractTTN, tn2::AbstractTTN; alg="fit", kwargs...) - return contract(Algorithm(alg), tn1, tn2; kwargs...) -end - -""" -Overload of `ITensors.apply`. -""" -function ITensors.apply(tn1::AbstractTTN, tn2::AbstractTTN; init, kwargs...) - init = init' - tn12 = contract(tn1, tn2; init, kwargs...) - return replaceprime(tn12, 1 => 0) -end - -function sum_apply( - tns::Vector{<:Tuple{<:AbstractTTN,<:AbstractTTN}}; alg="fit", init, kwargs... -) - init = init' - tn12 = sum_contract(Algorithm(alg), tns; init, kwargs...) - return replaceprime(tn12, 1 => 0) -end diff --git a/src/solvers/previous_interfaces/dmrg.jl b/src/solvers/previous_interfaces/dmrg.jl deleted file mode 100644 index 1acbde35..00000000 --- a/src/solvers/previous_interfaces/dmrg.jl +++ /dev/null @@ -1,23 +0,0 @@ -using KrylovKit: KrylovKit - -function dmrg( - operator, - init_state; - nsweeps, - nsites=2, - updater=eigsolve_updater, - (region_observer!)=nothing, - kwargs..., -) - eigvals_ref = Ref{Any}() - region_observer! = compose_observers( - region_observer!, ValuesObserver((; eigvals=eigvals_ref)) - ) - state = alternating_update( - operator, init_state; nsweeps, nsites, updater, region_observer!, kwargs... - ) - eigval = only(eigvals_ref[]) - return eigval, state -end - -KrylovKit.eigsolve(H, init::AbstractTTN; kwargs...) = dmrg(H, init; kwargs...) diff --git a/src/solvers/previous_interfaces/dmrg_x.jl b/src/solvers/previous_interfaces/dmrg_x.jl deleted file mode 100644 index 7ab9d8cd..00000000 --- a/src/solvers/previous_interfaces/dmrg_x.jl +++ /dev/null @@ -1,19 +0,0 @@ -function dmrg_x( - operator, - init_state::AbstractTTN; - nsweeps, - nsites=2, - updater=dmrg_x_updater, - (region_observer!)=nothing, - kwargs..., -) - eigvals_ref = Ref{Any}() - region_observer! = compose_observers( - region_observer!, ValuesObserver((; eigvals=eigvals_ref)) - ) - state = alternating_update( - operator, init_state; nsweeps, nsites, updater, region_observer!, kwargs... - ) - eigval = only(eigvals_ref[]) - return eigval, state -end diff --git a/src/solvers/previous_interfaces/linsolve.jl b/src/solvers/previous_interfaces/linsolve.jl deleted file mode 100644 index acd93cef..00000000 --- a/src/solvers/previous_interfaces/linsolve.jl +++ /dev/null @@ -1,49 +0,0 @@ -using DocStringExtensions: TYPEDSIGNATURES -using KrylovKit: KrylovKit - -""" -$(TYPEDSIGNATURES) - -Compute a solution x to the linear system: - -(a₀ + a₁ * A)*x = b - -using starting guess x₀. Leaving a₀, a₁ -set to their default values solves the -system A*x = b. - -To adjust the balance between accuracy of solution -and speed of the algorithm, it is recommed to first try -adjusting the `solver_tol` keyword argument descibed below. - -Keyword arguments: - - `ishermitian::Bool=false` - should set to true if the MPO A is Hermitian - - `solver_krylovdim::Int=30` - max number of Krylov vectors to build on each solver iteration - - `solver_maxiter::Int=100` - max number outer iterations (restarts) to do in the solver step - - `solver_tol::Float64=1E-14` - tolerance or error goal of the solver - -Overload of `KrylovKit.linsolve`. -""" -function KrylovKit.linsolve( - A::AbstractTTN, - b::AbstractTTN, - x₀::AbstractTTN, - a₀::Number=0, - a₁::Number=1; - updater=linsolve_updater, - nsites=2, - nsweeps, #it makes sense to require this to be defined - updater_kwargs=(;), - kwargs..., -) - updater_kwargs = (; a₀, a₁, updater_kwargs...) - error("`linsolve` for TTN not yet implemented.") - - # TODO: Define `itensornetwork_cache` - # TODO: Define `linsolve_cache` - - P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b)) - return alternating_update( - P, x₀; nsweeps, nsites, updater=linsolve_updater, updater_kwargs, kwargs... - ) -end diff --git a/src/solvers/previous_interfaces/tdvp.jl b/src/solvers/previous_interfaces/tdvp.jl deleted file mode 100644 index 7a58fe1b..00000000 --- a/src/solvers/previous_interfaces/tdvp.jl +++ /dev/null @@ -1,154 +0,0 @@ -using NamedGraphs.GraphsExtensions: GraphsExtensions - -#ToDo: Cleanup _compute_nsweeps, maybe restrict flexibility to simplify code -function _compute_nsweeps(nsweeps::Int, t::Number, time_step::Number) - return error("Cannot specify both nsweeps and time_step in tdvp") -end - -function _compute_nsweeps(nsweeps::Nothing, t::Number, time_step::Nothing) - return 1, [t] -end - -function _compute_nsweeps(nsweeps::Nothing, t::Number, time_step::Number) - @assert isfinite(time_step) && abs(time_step) > 0.0 - nsweeps = convert(Int, ceil(abs(t / time_step))) - if !(nsweeps * time_step ≈ t) - println("Time that will be reached = nsweeps * time_step = ", nsweeps * time_step) - println("Requested total time t = ", t) - error("Time step $time_step not commensurate with total time t=$t") - end - return nsweeps, extend_or_truncate(time_step, nsweeps) -end - -function _compute_nsweeps(nsweeps::Int, t::Number, time_step::Nothing) - time_step = extend_or_truncate(t / nsweeps, nsweeps) - return nsweeps, time_step -end - -function _compute_nsweeps(nsweeps, t::Number, time_step::Vector) - diff_time = t - sum(time_step) - - isnothing(nsweeps) - if isnothing(nsweeps) - #extend_or_truncate time_step to reach final time t - last_time_step = last(time_step) - nsweepstopad = Int(ceil(abs(diff_time / last_time_step))) - if !(sum(time_step) + nsweepstopad * last_time_step ≈ t) - println( - "Time that will be reached = nsweeps * time_step = ", - sum(time_step) + nsweepstopad * last_time_step, - ) - println("Requested total time t = ", t) - error("Time step $time_step not commensurate with total time t=$t") - end - time_step = extend_or_truncate(time_step, length(time_step) + nsweepstopad) - nsweeps = length(time_step) - else - nsweepstopad = nsweeps - length(time_step) - if abs(diff_time) < eps() && !iszero(nsweepstopad) - warn( - "A vector of timesteps that sums up to total time t=$t was supplied, - but its length (=$(length(time_step))) does not agree with supplied number of sweeps (=$(nsweeps)).", - ) - return length(time_step), time_step - end - remaining_time_step = diff_time / nsweepstopad - append!(time_step, extend_or_truncate(remaining_time_step, nsweepstopad)) - end - return nsweeps, time_step -end - -function sub_time_steps(order) - if order == 1 - return [1.0] - elseif order == 2 - return [1 / 2, 1 / 2] - elseif order == 4 - s = 1.0 / (2 - 2^(1 / 3)) - return [s / 2, s / 2, (1 - 2 * s) / 2, (1 - 2 * s) / 2, s / 2, s / 2] - else - error("Trotter order of $order not supported") - end -end - -""" - tdvp(operator::TTN, t::Number, init_state::TTN; kwargs...) - -Use the time dependent variational principle (TDVP) algorithm -to approximately compute `exp(operator*t)*init_state` using an efficient algorithm based -on alternating optimization of the state tensors and local Krylov -exponentiation of operator. The time parameter `t` can be a real or complex number. - -Returns: -* `state` - time-evolved state - -Optional keyword arguments: -* `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run. -* `nsteps::Integer` - evolve by the requested total time `t` by performing `nsteps` of the TDVP algorithm. More steps can result in more accurate results but require more computational time to run. (Note that only one of the `time_step` or `nsteps` parameters can be provided, not both.) -* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output -* `observer` - object implementing the Observer interface which can perform measurements and stop early -* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations -""" -function tdvp( - operator, - t::Number, - init_state::AbstractTTN; - t_start=0.0, - time_step=nothing, - nsites=2, - nsweeps=nothing, - order::Integer=2, - outputlevel=default_outputlevel(), - region_printer=nothing, - sweep_printer=nothing, - (sweep_observer!)=nothing, - (region_observer!)=nothing, - root_vertex=GraphsExtensions.default_root_vertex(init_state), - reverse_step=true, - extracter_kwargs=(;), - extracter=default_extracter(), # ToDo: extracter could be inside extracter_kwargs, at the cost of having to extract it in region_update - updater_kwargs=(;), - updater=exponentiate_updater, - inserter_kwargs=(;), - inserter=default_inserter(), - transform_operator_kwargs=(;), - transform_operator=default_transform_operator(), - kwargs..., -) - # move slurped kwargs into inserter - inserter_kwargs = (; inserter_kwargs..., kwargs...) - # process nsweeps and time_step - nsweeps, time_step = _compute_nsweeps(nsweeps, t, time_step) - t_evolved = t_start .+ cumsum(time_step) - sweep_plans = default_sweep_plans( - nsweeps, - init_state; - sweep_plan_func=tdvp_sweep_plan, - root_vertex, - reverse_step, - extracter, - extracter_kwargs, - updater, - updater_kwargs, - inserter, - inserter_kwargs, - transform_operator, - transform_operator_kwargs, - time_step, - order, - nsites, - t_evolved, - ) - - return alternating_update( - operator, - init_state, - sweep_plans; - outputlevel, - sweep_observer!, - region_observer!, - sweep_printer, - region_printer, - ) - return state -end diff --git a/test/_test_ttn_contract.jl b/test/_test_ttn_contract.jl deleted file mode 100644 index f3cdbc48..00000000 --- a/test/_test_ttn_contract.jl +++ /dev/null @@ -1,154 +0,0 @@ -@eval module $(gensym()) -using Graphs: vertices -using ITensorNetworks: - ITensorNetworks, - OpSum, - ProjOuterProdTTN, - ProjTTNSum, - ttn, - apply, - contract, - delta, - dmrg, - inner, - mpo, - random_mps, - random_ttn, - siteinds -using ITensorNetworks.ModelHamiltonians: ModelHamiltonians -using ITensors: prime, replaceinds, replaceprime -using LinearAlgebra: norm, normalize -using NamedGraphs.NamedGraphGenerators: named_comb_tree -using StableRNGs: StableRNG -using Test: @test, @test_broken, @testset - -# These tests are broken currently - -@testset "Contract MPO" begin - N = 20 - s = siteinds("S=1/2", N) - rng = StableRNG(1234) - psi = random_mps(rng, s; link_space=8) - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - for j in 1:(N - 2) - os += 0.5, "S+", j, "S-", j + 2 - os += 0.5, "S-", j, "S+", j + 2 - os += "Sz", j, "Sz", j + 2 - end - H = mpo(os, s) - - # Test basic usage with default parameters - Hpsi = apply(H, psi; alg="fit", init=psi, nsweeps=1) - @test inner(psi, Hpsi) ≈ inner(psi', H, psi) rtol = 1e-5 - # Test variational compression via DMRG - Hfit = ProjOuterProdTTN(psi', H) - e, Hpsi_via_dmrg = dmrg(Hfit, psi; updater_kwargs=(; which_eigval=:LR,), nsweeps=1) - @test abs(inner(Hpsi_via_dmrg, Hpsi / norm(Hpsi))) ≈ 1 rtol = 1e-4 - # Test whether the interface works for ProjTTNSum with factors - Hfit = ProjTTNSum([ProjOuterProdTTN(psi', H), ProjOuterProdTTN(psi', H)], [-0.2, -0.8]) - e, Hpsi_via_dmrg = dmrg(Hfit, psi; nsweeps=1, updater_kwargs=(; which_eigval=:SR,)) - @test abs(inner(Hpsi_via_dmrg, Hpsi / norm(Hpsi))) ≈ 1 rtol = 1e-4 - - # Test basic usage for use with multiple ProjOuterProdTTN with default parameters - # BLAS.axpy-like test - os_id = OpSum() - os_id += -1, "Id", 1, "Id", 2 - minus_identity = mpo(os_id, s) - os_id = OpSum() - os_id += +1, "Id", 1, "Id", 2 - identity = mpo(os_id, s) - Hpsi = ITensorNetworks.sum_apply( - [(H, psi), (minus_identity, psi)]; alg="fit", init=psi, nsweeps=3 - ) - @test inner(psi, Hpsi) ≈ (inner(psi', H, psi) - norm(psi)^2) rtol = 1e-5 - # Test the above via DMRG - # ToDo: Investigate why this is broken - Hfit = ProjTTNSum([ProjOuterProdTTN(psi', H), ProjOuterProdTTN(psi', identity)], [-1, 1]) - e, Hpsi_normalized = dmrg(Hfit, psi; nsweeps=3, updater_kwargs=(; which_eigval=:SR)) - @test_broken abs(inner(Hpsi, (Hpsi_normalized) / norm(Hpsi))) ≈ 1 rtol = 1e-5 - - # - # Change "top" indices of MPO to be a different set - # - t = siteinds("S=1/2", N) - psit = deepcopy(psi) - - for j in 1:N - H[j] *= delta(s[j]', t[j]) - psit[j] *= delta(s[j], t[j]) - end - # Test with nsweeps=3 - Hpsi = contract(H, psi; alg="fit", init=psit, nsweeps=3) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-5 - # Test with less good initial guess MPS not equal to psi - psi_guess = truncate(psit; maxdim=2) - Hpsi = contract(H, psi; alg="fit", nsweeps=4, init=psi_guess) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-5 - - # Test with nsite=1 - rng = StableRNG(1234) - Hpsi_guess = random_mps(rng, t; link_space=32) - Hpsi = contract(H, psi; alg="fit", init=Hpsi_guess, nsites=1, nsweeps=4) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-4 -end - -@testset "Contract TTN" begin - tooth_lengths = fill(4, 4) - root_vertex = (1, 4) - c = named_comb_tree(tooth_lengths) - - s = siteinds("S=1/2", c) - rng = StableRNG(1234) - psi = normalize(random_ttn(rng, s; link_space=8)) - - os = ModelHamiltonians.heisenberg(c; J1=1, J2=1) - H = ttn(os, s) - - # Test basic usage with default parameters - Hpsi = apply(H, psi; alg="fit", init=psi, nsweeps=1, cutoff=eps()) - @test inner(psi, Hpsi) ≈ inner(psi', H, psi) rtol = 1e-5 - # Test usage with non-default parameters - Hpsi = apply( - H, psi; alg="fit", init=psi, nsweeps=5, maxdim=[16, 32], cutoff=[1e-4, 1e-8, 1e-12] - ) - @test inner(psi, Hpsi) ≈ inner(psi', H, psi) rtol = 1e-2 - - # Test basic usage for multiple ProjOuterProdTTN with default parameters - # BLAS.axpy-like test - os_id = OpSum() - os_id += -1, "Id", first(vertices(s)), "Id", first(vertices(s)) - minus_identity = ttn(os_id, s) - Hpsi = ITensorNetworks.sum_apply( - [(H, psi), (minus_identity, psi)]; alg="fit", init=psi, nsweeps=1 - ) - @test inner(psi, Hpsi) ≈ (inner(psi', H, psi) - norm(psi)^2) rtol = 1e-5 - - # - # Change "top" indices of TTN to be a different set - # - t = siteinds("S=1/2", c) - psit = deepcopy(psi) - psit = replaceinds(psit, s => t) - H = replaceinds(H, prime(s; links=[]) => t) - - # Test with nsweeps=2 - Hpsi = contract(H, psi; alg="fit", init=psit, nsweeps=2) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-5 - - # Test with less good initial guess MPS not equal to psi - Hpsi_guess = truncate(psit; maxdim=2) - Hpsi = contract(H, psi; alg="fit", nsweeps=4, init=Hpsi_guess) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-5 - - # Test with nsite=1 - rng = StableRNG(1234) - Hpsi_guess = random_ttn(rng, t; link_space=32) - Hpsi = contract(H, psi; alg="fit", nsites=1, nsweeps=10, init=Hpsi_guess) - @test inner(psit, Hpsi) ≈ inner(psit, H, psi) rtol = 1e-2 -end -end diff --git a/test/_test_ttn_dmrg.jl b/test/_test_ttn_dmrg.jl deleted file mode 100644 index b8a8cdb8..00000000 --- a/test/_test_ttn_dmrg.jl +++ /dev/null @@ -1,328 +0,0 @@ -@eval module $(gensym()) -using DataGraphs: edge_data, vertex_data -using Dictionaries: Dictionary -using Graphs: nv, vertices, uniform_tree -using ITensorMPS: ITensorMPS -using ITensorNetworks: - ITensorNetworks, - OpSum, - ttn, - apply, - dmrg, - inner, - mpo, - random_mps, - random_ttn, - linkdims, - siteinds -using ITensorNetworks.ITensorsExtensions: replace_vertices -using ITensorNetworks.ModelHamiltonians: ModelHamiltonians -using ITensors: ITensors -using KrylovKit: eigsolve -using NamedGraphs: NamedGraph, rename_vertices -using NamedGraphs.NamedGraphGenerators: named_comb_tree -using Observers: observer -using StableRNGs: StableRNG -using Suppressor: @capture_out -using Test: @test, @test_broken, @testset - -# This is needed since `eigen` is broken -# if there are no QNs and auto-fermion -# is enabled. -ITensors.disable_auto_fermion() - -@testset "MPS DMRG" for nsites in [1, 2] - N = 10 - cutoff = 1e-12 - - s = siteinds("S=1/2", N) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - rng = StableRNG(1234) - psi = random_mps(rng, s; link_space=20) - - nsweeps = 10 - maxdim = [10, 20, 40, 100] - - # Compare to `ITensors.MPO` version of `dmrg` - H_mpo = ITensorMPS.MPO([H[v] for v in 1:nv(H)]) - psi_mps = ITensorMPS.MPS([psi[v] for v in 1:nv(psi)]) - e2, psi2 = ITensorMPS.dmrg(H_mpo, psi_mps; nsweeps, maxdim, outputlevel=0) - - e, psi = dmrg( - H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) - ) - @test inner(psi', H, psi) ≈ e - @test inner(psi', H, psi) ≈ inner(psi2', H_mpo, psi2) - - # Alias for `ITensorNetworks.dmrg` - e, psi = eigsolve( - H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) - ) - @test inner(psi', H, psi) ≈ e - @test inner(psi', H, psi) ≈ inner(psi2', H_mpo, psi2) - - # Test custom sweep regions #BROKEN, ToDo: Make proper custom sweep regions for test - #= - orig_E = inner(psi', H, psi) - sweep_regions = [[1], [2], [3], [3], [2], [1]] - e, psi = dmrg(H, psi; nsweeps, maxdim, cutoff, sweep_regions) - new_E = inner(psi', H, psi) - @test new_E ≈ orig_E - =# - - # - # Test outputlevels are working - # - prev_output = "" - for outputlevel in 0:2 - output = @capture_out begin - e, psi = dmrg( - H, - psi; - outputlevel, - nsweeps, - maxdim, - cutoff, - nsites, - updater_kwargs=(; krylovdim=3, maxiter=1), - ) - end - if outputlevel == 0 - @test length(output) == 0 - else - @test length(output) > length(prev_output) - end - prev_output = output - end -end - -@testset "Observers" begin - N = 10 - cutoff = 1e-12 - s = siteinds("S=1/2", N) - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - H = mpo(os, s) - rng = StableRNG(1234) - psi = random_mps(rng, s; link_space=20) - - nsweeps = 4 - maxdim = [20, 40, 80, 80] - cutoff = [1e-10] - - # - # Make observers - # - sweep(; which_sweep, kw...) = which_sweep - sweep_observer! = observer(sweep) - - region(; which_region_update, sweep_plan, kw...) = first(sweep_plan[which_region_update]) - energy(; eigvals, kw...) = eigvals[1] - region_observer! = observer(region, sweep, energy) - - e, psi = dmrg(H, psi; nsweeps, maxdim, cutoff, sweep_observer!, region_observer!) - - # - # Test out certain values - # - @test region_observer![9, :region] == [2, 1] - @test region_observer![30, :energy] < -4.25 - @test region_observer![30, :energy] ≈ e rtol = 1e-6 -end - -@testset "Cache to Disk" begin - N = 10 - cutoff = 1e-12 - s = siteinds("S=1/2", N) - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - H = mpo(os, s) - rng = StableRNG(1234) - psi = random_mps(rng, s; link_space=10) - - nsweeps = 4 - maxdim = [10, 20, 40, 80] - - @test_broken e, psi = dmrg( - H, - psi; - nsweeps, - maxdim, - cutoff, - outputlevel=0, - transform_operator=ITensorNetworks.cache_operator_to_disk, - transform_operator_kwargs=(; write_when_maxdim_exceeds=11), - ) -end - -@testset "Regression test: Arrays of Parameters" begin - N = 10 - cutoff = 1e-12 - - s = siteinds("S=1/2", N) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - rng = StableRNG(1234) - psi = random_mps(rng, s; link_space=20) - - # Choose nsweeps to be less than length of arrays - nsweeps = 5 - maxdim = [200, 250, 400, 600, 800, 1200, 2000, 2400, 2600, 3000] - cutoff = [1e-10, 1e-10, 1e-12, 1e-12, 1e-12, 1e-12, 1e-14, 1e-14, 1e-14, 1e-14] - - e, psi = dmrg(H, psi; nsweeps, maxdim, cutoff) -end - -@testset "Tree DMRG" for nsites in [2] - cutoff = 1e-12 - - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) - - @testset "SVD approach" for use_qns in [false, true] - auto_fermion_enabled = ITensors.using_auto_fermion() - if use_qns # test whether autofermion breaks things when using non-fermionic QNs - ITensors.enable_auto_fermion() - else # when using no QNs, autofermion breaks # ToDo reference Issue in ITensors - ITensors.disable_auto_fermion() - end - s = siteinds("S=1/2", c; conserve_qns=use_qns) - - os = ModelHamiltonians.heisenberg(c) - - H = ttn(os, s) - - # make init_state - d = Dict() - for (i, v) in enumerate(vertices(s)) - d[v] = isodd(i) ? "Up" : "Dn" - end - states = v -> d[v] - psi = ttn(states, s) - - # rng = StableRNG(1234) - # psi = random_ttn(rng, s; link_space=20) #FIXME: random_ttn broken for QN conserving case - - nsweeps = 10 - maxdim = [10, 20, 40, 100] - @show use_qns - e, psi = dmrg( - H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) - ) - - # Compare to `ITensors.MPO` version of `dmrg` - linear_order = [4, 1, 2, 5, 3, 6] - vmap = Dictionary(collect(vertices(s))[linear_order], 1:length(linear_order)) - sline = only.(collect(vertex_data(s)))[linear_order] - Hline = ITensorMPS.MPO(replace_vertices(v -> vmap[v], os), sline) - rng = StableRNG(1234) - psiline = ITensorMPS.random_mps(rng, sline, i -> isodd(i) ? "Up" : "Dn"; linkdims=20) - e2, psi2 = ITensorMPS.dmrg(Hline, psiline; nsweeps, maxdim, cutoff, outputlevel=0) - - @test inner(psi', H, psi) ≈ ITensorMPS.inner(psi2', Hline, psi2) atol = 1e-5 - - if !auto_fermion_enabled - ITensors.disable_auto_fermion() - end - end -end - -@testset "Tree DMRG for Fermions" for nsites in [2] #ToDo: change to [1,2] when random_ttn works with QNs - auto_fermion_enabled = ITensors.using_auto_fermion() - use_qns = true - cutoff = 1e-12 - nsweeps = 10 - maxdim = [10, 20, 40, 100] - - # setup model - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) - s = siteinds("Electron", c; conserve_qns=use_qns) - U = 2.0 - t = 1.3 - tp = 0.6 - os = ModelHamiltonians.hubbard(c; U, t, tp) - - # for conversion to ITensors.MPO - linear_order = [4, 1, 2, 5, 3, 6] - vmap = Dictionary(collect(vertices(s))[linear_order], 1:length(linear_order)) - sline = only.(collect(vertex_data(s)))[linear_order] - - # get MPS / MPO with JW string result - ITensors.disable_auto_fermion() - Hline = ITensorMPS.MPO(replace_vertices(v -> vmap[v], os), sline) - rng = StableRNG(1234) - psiline = ITensorMPS.random_mps(rng, sline, i -> isodd(i) ? "Up" : "Dn"; linkdims=20) - e_jw, psi_jw = ITensorMPS.dmrg(Hline, psiline; nsweeps, maxdim, cutoff, outputlevel=0) - ITensors.enable_auto_fermion() - - # now get auto-fermion results - H = ttn(os, s) - # make init_state - d = Dict() - for (i, v) in enumerate(vertices(s)) - d[v] = isodd(i) ? "Up" : "Dn" - end - states = v -> d[v] - psi = ttn(states, s) - e, psi = dmrg( - H, psi; nsweeps, maxdim, cutoff, nsites, updater_kwargs=(; krylovdim=3, maxiter=1) - ) - - # Compare to `ITensors.MPO` version of `dmrg` - Hline = ITensorMPS.MPO(replace_vertices(v -> vmap[v], os), sline) - rng = StableRNG(1234) - psiline = ITensorMPS.random_mps(rng, sline, i -> isodd(i) ? "Up" : "Dn"; linkdims=20) - e2, psi2 = ITensorMPS.dmrg(Hline, psiline; nsweeps, maxdim, cutoff, outputlevel=0) - - @test inner(psi', H, psi) ≈ ITensorMPS.inner(psi2', Hline, psi2) atol = 1e-5 - @test e2 ≈ e_jw atol = 1e-5 - @test inner(psi2', Hline, psi2) ≈ e_jw atol = 1e-5 - - if !auto_fermion_enabled - ITensors.disable_auto_fermion() - end -end - -@testset "Regression test: tree truncation" begin - maxdim = 4 - nsites = 2 - nsweeps = 10 - - rng = StableRNG(1234) - g = NamedGraph(uniform_tree(10)) - g = rename_vertices(v -> (v, 1), g) - s = siteinds("S=1/2", g) - os = ModelHamiltonians.heisenberg(g) - H = ttn(os, s) - psi = random_ttn(rng, s; link_space=5) - e, psi = dmrg(H, psi; nsweeps, maxdim, nsites) - - @test all(edge_data(linkdims(psi)) .<= maxdim) -end -end diff --git a/test/_test_ttn_dmrg_x.jl b/test/_test_ttn_dmrg_x.jl deleted file mode 100644 index 4f2583ac..00000000 --- a/test/_test_ttn_dmrg_x.jl +++ /dev/null @@ -1,70 +0,0 @@ -@eval module $(gensym()) -using Dictionaries: Dictionary -using Graphs: nv, vertices -using ITensorNetworks: - OpSum, ttn, apply, contract, dmrg_x, inner, mpo, mps, random_mps, siteinds -using ITensorNetworks.ModelHamiltonians: ModelHamiltonians -using ITensors: @disable_warn_order, array, dag, onehot, uniqueind -using LinearAlgebra: eigen, normalize -using NamedGraphs.NamedGraphGenerators: named_comb_tree -using StableRNGs: StableRNG -using Test: @test, @testset -# TODO: Combine MPS and TTN tests. -@testset "MPS DMRG-X" for conserve_qns in (false, true) - n = 10 - s = siteinds("S=1/2", n; conserve_qns) - W = 12 - # Random fields h ∈ [-W, W] - rng = StableRNG(1234) - h = W * (2 * rand(rng, n) .- 1) - H = mpo(ModelHamiltonians.heisenberg(n; h), s) - ψ = mps(v -> rand(rng, ["↑", "↓"]), s) - dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0) - e, ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...) - @test inner(ϕ', H, ϕ) / inner(ϕ, ϕ) ≈ e - @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ', H, ϕ) / inner(ϕ, ϕ) rtol = 1e-1 - @test inner(H, ψ, H, ψ) ≉ inner(ψ', H, ψ)^2 rtol = 1e-7 - @test inner(H, ϕ, H, ϕ) ≈ inner(ϕ', H, ϕ)^2 rtol = 1e-7 - e, ϕ̃ = dmrg_x(H, ϕ; nsites=1, dmrg_x_kwargs...) - @test inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) ≈ e - @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) rtol = 1e-1 - @test inner(H, ϕ̃, H, ϕ̃) ≈ inner(ϕ̃', H, ϕ̃)^2 rtol = 1e-3 - # Sometimes broken, sometimes not - # @test abs(loginner(ϕ̃, ϕ) / n) ≈ 0.0 atol = 1e-6 -end -@testset "Tree DMRG-X" for conserve_qns in (false, true) - # TODO: Combine with tests above into a loop over graph structures. - tooth_lengths = fill(2, 3) - root_vertex = (3, 2) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c; conserve_qns) - W = 12 - # Random fields h ∈ [-W, W] - rng = StableRNG(123) - h = Dictionary(vertices(c), W * (2 * rand(rng, nv(c)) .- 1)) - H = ttn(ModelHamiltonians.heisenberg(c; h), s) - ψ = normalize(ttn(v -> rand(rng, ["↑", "↓"]), s)) - dmrg_x_kwargs = (nsweeps=20, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=0) - e, ϕ = dmrg_x(H, ψ; nsites=2, dmrg_x_kwargs...) - @test inner(ϕ', H, ϕ) / inner(ϕ, ϕ) ≈ e - @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ', H, ϕ) / inner(ϕ, ϕ) rtol = 1e-1 - @test inner(H, ψ, H, ψ) ≉ inner(ψ', H, ψ)^2 rtol = 1e-2 - @test inner(H, ϕ, H, ϕ) ≈ inner(ϕ', H, ϕ)^2 rtol = 1e-7 - e, ϕ̃ = dmrg_x(H, ϕ; nsites=1, dmrg_x_kwargs...) - @test inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) ≈ e - @test inner(ψ', H, ψ) / inner(ψ, ψ) ≈ inner(ϕ̃', H, ϕ̃) / inner(ϕ̃, ϕ̃) rtol = 1e-1 - @test inner(H, ϕ̃, H, ϕ̃) ≈ inner(ϕ̃', H, ϕ̃)^2 rtol = 1e-6 - # Sometimes broken, sometimes not - # @test abs(loginner(ϕ̃, ϕ) / nv(c)) ≈ 0.0 atol = 1e-8 - # compare against ED - @disable_warn_order U0 = contract(ψ, root_vertex) - @disable_warn_order T = contract(H, root_vertex) - D, U = eigen(T; ishermitian=true) - u = uniqueind(U, T) - _, max_ind = findmax(abs, array(dag(U0) * U)) - U_exact = U * dag(onehot(u => max_ind)) - @disable_warn_order U_dmrgx = contract(ϕ, root_vertex) - @test inner(ϕ', H, ϕ) ≈ (dag(U_exact') * T * U_exact)[] atol = 1e-6 - @test abs(inner(U_dmrgx, U_exact)) ≈ 1 atol = 1e-6 -end -end diff --git a/test/_test_ttn_linsolve.jl b/test/_test_ttn_linsolve.jl deleted file mode 100644 index dab969ed..00000000 --- a/test/_test_ttn_linsolve.jl +++ /dev/null @@ -1,48 +0,0 @@ -@eval module $(gensym()) -using ITensorNetworks: ITensorNetworks, OpSum, apply, dmrg, inner, mpo, random_mps, siteinds -using KrylovKit: linsolve -using StableRNGs: StableRNG -using Test: @test, @test_broken, @testset - -@testset "Linsolve" begin - @testset "Linsolve Basics" begin - cutoff = 1E-11 - maxdim = 8 - nsweeps = 2 - - N = 8 - # s = siteinds("S=1/2", N; conserve_qns=true) - s = siteinds("S=1/2", N; conserve_qns=false) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - H = mpo(os, s) - - # - # Test complex case - # - - rng = StableRNG(1234) - ## TODO: Need to add support for `random_mps`/`random_tensornetwork` with state input. - ## states = [isodd(n) ? "Up" : "Dn" for n in 1:N] - ## x_c = random_mps(rng, states, s; link_space=4) + 0.1im * random_mps(rng, states, s; link_space=2) - x_c = random_mps(rng, s; link_space=4) + 0.1im * random_mps(rng, s; link_space=2) - - b = apply(H, x_c; alg="fit", nsweeps=3, init=x_c) #cutoff is unsupported kwarg for apply/contract - - ## TODO: Need to add support for `random_mps`/`random_tensornetwork` with state input. - ## x0 = random_mps(rng, states, s; link_space=10) - x0 = random_mps(rng, s; link_space=10) - - x = @test_broken linsolve( - H, b, x0; cutoff, maxdim, nsweeps, updater_kwargs=(; tol=1E-6, ishermitian=true) - ) - - # @test norm(x - x_c) < 1E-3 - end -end -end diff --git a/test/_test_ttn_tdvp.jl b/test/_test_ttn_tdvp.jl deleted file mode 100644 index f4426d21..00000000 --- a/test/_test_ttn_tdvp.jl +++ /dev/null @@ -1,663 +0,0 @@ -@eval module $(gensym()) -using Graphs: dst, edges, src -using ITensors: ITensor, contract, dag, inner, noprime, normalize, prime, scalar -using ITensorNetworks: - ITensorNetworks, - OpSum, - ttn, - apply, - expect, - mpo, - mps, - op, - random_mps, - random_ttn, - siteinds, - tdvp -using ITensorNetworks.ModelHamiltonians: ModelHamiltonians -using LinearAlgebra: norm -using NamedGraphs.NamedGraphGenerators: named_binary_tree, named_comb_tree -using Observers: observer -using StableRNGs: StableRNG -using Test: @testset, @test -@testset "MPS TDVP" begin - @testset "Basic TDVP" begin - N = 10 - cutoff = 1e-12 - - s = siteinds("S=1/2", N) - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - rng = StableRNG(1234) - ψ0 = random_mps(rng, s; link_space=10) - - # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; nsweeps=1, cutoff, nsites=1) - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(inner(ψ1', H, ψ1)) ≈ inner(ψ0', H, ψ0) - - # Time evolve backwards: - ψ2 = tdvp( - H, - +0.1im, - ψ1; - nsweeps=1, - cutoff, - updater_kwargs=(; krylovdim=20, maxiter=20, tol=1e-8), - ) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - - # test different ways to specify time-step specifications - ψa = tdvp(H, -0.1im, ψ0; nsweeps=4, cutoff, nsites=1) - ψb = tdvp(H, -0.1im, ψ0; time_step=-0.025im, cutoff, nsites=1) - ψc = tdvp( - H, -0.1im, ψ0; time_step=[-0.02im, -0.03im, -0.015im, -0.035im], cutoff, nsites=1 - ) - ψd = tdvp( - H, -0.1im, ψ0; nsweeps=4, time_step=[-0.02im, -0.03im, -0.025im], cutoff, nsites=1 - ) - @test inner(ψa, ψb) ≈ 1.0 rtol = 1e-7 - @test inner(ψa, ψc) ≈ 1.0 rtol = 1e-7 - @test inner(ψa, ψd) ≈ 1.0 rtol = 1e-7 - end - - @testset "TDVP: Sum of Hamiltonians" begin - N = 10 - cutoff = 1e-10 - - s = siteinds("S=1/2", N) - - os1 = OpSum() - for j in 1:(N - 1) - os1 += 0.5, "S+", j, "S-", j + 1 - os1 += 0.5, "S-", j, "S+", j + 1 - end - os2 = OpSum() - for j in 1:(N - 1) - os2 += "Sz", j, "Sz", j + 1 - end - - H1 = mpo(os1, s) - H2 = mpo(os2, s) - Hs = [H1, H2] - - rng = StableRNG(1234) - ψ0 = random_mps(rng, s; link_space=10) - - ψ1 = tdvp(Hs, -0.1im, ψ0; nsweeps=1, cutoff, nsites=1) - - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(sum(H -> inner(ψ1', H, ψ1), Hs)) ≈ sum(H -> inner(ψ0', H, ψ0), Hs) - - # Time evolve backwards: - ψ2 = tdvp(Hs, +0.1im, ψ1; nsweeps=1, cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - - @testset "Higher-Order TDVP" begin - N = 10 - cutoff = 1e-12 - order = 4 - - s = siteinds("S=1/2", N) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - rng = StableRNG(1234) - ψ0 = random_mps(rng, s; link_space=10) - - # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; time_step=-0.05im, order, cutoff, nsites=1) - - @test norm(ψ1) ≈ 1.0 - - # Average energy should be conserved: - @test real(inner(ψ1', H, ψ1)) ≈ inner(ψ0', H, ψ0) - - # Time evolve backwards: - ψ2 = tdvp(H, +0.1im, ψ1; time_step=+0.05im, order, cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - - @testset "Accuracy Test" begin - N = 4 - tau = 0.1 - ttotal = 1.0 - cutoff = 1e-12 - - s = siteinds("S=1/2", N; conserve_qns=false) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - H = mpo(os, s) - HM = contract(H) - - Ut = exp(-im * tau * HM) - - state = mps(n -> isodd(n) ? "Up" : "Dn", s) - psi2 = deepcopy(state) - psix = contract(state) - - Sz_tdvp = Float64[] - Sz_tdvp2 = Float64[] - Sz_exact = Float64[] - - c = div(N, 2) - Szc = op("Sz", s[c]) - - Nsteps = Int(ttotal / tau) - for step in 1:Nsteps - psix = noprime(Ut * psix) - psix /= norm(psix) - - state = tdvp( - H, - -im * tau, - state; - cutoff, - normalize=false, - updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), - ) - # TODO: What should `expect` output? Right now - # it outputs a dictionary. - push!(Sz_tdvp, real(expect("Sz", state; vertices=[c])[c])) - - psi2 = tdvp( - H, - -im * tau, - psi2; - cutoff, - normalize=false, - updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), - updater=ITensorNetworks.exponentiate_updater, - ) - # TODO: What should `expect` output? Right now - # it outputs a dictionary. - push!(Sz_tdvp2, real(expect("Sz", psi2; vertices=[c])[c])) - - push!(Sz_exact, real(scalar(dag(prime(psix, s[c])) * Szc * psix))) - F = abs(scalar(dag(psix) * contract(state))) - end - - @test norm(Sz_tdvp - Sz_exact) < 1e-5 - @test norm(Sz_tdvp2 - Sz_exact) < 1e-5 - end - - @testset "TEBD Comparison" begin - N = 10 - cutoff = 1e-12 - tau = 0.1 - ttotal = 1.0 - - s = siteinds("S=1/2", N; conserve_qns=true) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - gates = ITensor[] - for j in 1:(N - 1) - s1 = s[j] - s2 = s[j + 1] - hj = - op("Sz", s1) * op("Sz", s2) + - 1 / 2 * op("S+", s1) * op("S-", s2) + - 1 / 2 * op("S-", s1) * op("S+", s2) - Gj = exp(-1.0im * tau / 2 * hj) - push!(gates, Gj) - end - append!(gates, reverse(gates)) - - state = mps(n -> isodd(n) ? "Up" : "Dn", s) - phi = deepcopy(state) - c = div(N, 2) - - # - # Evolve using TEBD - # - - Nsteps = convert(Int, ceil(abs(ttotal / tau))) - Sz1 = zeros(Nsteps) - En1 = zeros(Nsteps) - #Sz2 = zeros(Nsteps) - #En2 = zeros(Nsteps) - - for step in 1:Nsteps - state = apply(gates, state; cutoff) - - nsites = (step <= 3 ? 2 : 1) - phi = tdvp( - H, - -tau * im, - phi; - nsweeps=1, - cutoff, - nsites, - normalize=true, - updater_kwargs=(; krylovdim=15), - ) - - Sz1[step] = real(expect("Sz", state; vertices=[c])[c]) - #Sz2[step] = real(expect("Sz", phi; vertices=[c])[c]) - En1[step] = real(inner(state', H, state)) - #En2[step] = real(inner(phi', H, phi)) - end - - # - # Evolve using TDVP - # - - phi = mps(n -> isodd(n) ? "Up" : "Dn", s) - - obs = observer( - "Sz" => (; state) -> expect("Sz", state; vertices=[c])[c], - "En" => (; state) -> real(inner(state', H, state)), - ) - - phi = tdvp( - H, - -im * ttotal, - phi; - time_step=-im * tau, - cutoff, - normalize=false, - (sweep_observer!)=obs, - root_vertex=N, # defaults to 1, which breaks observer equality - ) - - Sz2 = obs.Sz - En2 = obs.En - @test norm(Sz1 - Sz2) < 1e-3 - @test norm(En1 - En2) < 1e-3 - end - - @testset "Imaginary Time Evolution" for reverse_step in [true, false] - cutoff = 1e-12 - tau = 1.0 - ttotal = 10.0 - N = 10 - s = siteinds("S=1/2", N) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - - H = mpo(os, s) - - rng = StableRNG(1234) - state = random_mps(rng, s; link_space=2) - en0 = inner(state', H, state) - nsites = [repeat([2], 10); repeat([1], 10)] - maxdim = 32 - state = tdvp( - H, - -ttotal, - state; - time_step=(-tau), - maxdim, - cutoff, - nsites, - reverse_step, - normalize=true, - updater_kwargs=(; krylovdim=15), - ) - en1 = inner(state', H, state) - @test en1 < en0 - end - - @testset "Observers" begin - N = 10 - cutoff = 1e-12 - tau = 0.1 - ttotal = 1.0 - - s = siteinds("S=1/2", N; conserve_qns=true) - - os = OpSum() - for j in 1:(N - 1) - os += 0.5, "S+", j, "S-", j + 1 - os += 0.5, "S-", j, "S+", j + 1 - os += "Sz", j, "Sz", j + 1 - end - H = mpo(os, s) - - c = div(N, 2) - - # - # Using Observers.jl - # - - measure_sz(; state) = expect("Sz", state; vertices=[c])[c] - measure_en(; state) = real(inner(state', H, state)) - sweep_obs = observer("Sz" => measure_sz, "En" => measure_en) - - get_info(; info) = info - step_measure_sz(; state) = expect("Sz", state; vertices=[c])[c] - step_measure_en(; state) = real(inner(state', H, state)) - region_obs = observer( - "Sz" => step_measure_sz, "En" => step_measure_en, "info" => get_info - ) - - state2 = mps(n -> isodd(n) ? "Up" : "Dn", s) - tdvp( - H, - -im * ttotal, - state2; - time_step=-im * tau, - cutoff, - normalize=false, - (sweep_observer!)=sweep_obs, - (region_observer!)=region_obs, - root_vertex=N, # defaults to 1, which breaks observer equality - ) - - Sz2 = sweep_obs.Sz - En2 = sweep_obs.En - - Sz2_step = region_obs.Sz - En2_step = region_obs.En - infos = region_obs.info - - # - # Could use ideas of other things to test here - # - - @test all(x -> x.converged == 1, infos) - end -end - -@testset "Tree TDVP" begin - @testset "Basic TDVP" for c in [named_comb_tree(fill(2, 3)), named_binary_tree(3)] - cutoff = 1e-12 - - tooth_lengths = fill(4, 4) - root_vertex = (1, 4) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os = ModelHamiltonians.heisenberg(c) - - H = ttn(os, s) - - rng = StableRNG(1234) - ψ0 = normalize(random_ttn(rng, s)) - - # Time evolve forward: - ψ1 = tdvp(H, -0.1im, ψ0; root_vertex, nsweeps=1, cutoff, nsites=2) - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(inner(ψ1', H, ψ1)) ≈ inner(ψ0', H, ψ0) - - # Time evolve backwards: - ψ2 = tdvp(H, +0.1im, ψ1; nsweeps=1, cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - - @testset "TDVP: Sum of Hamiltonians" begin - cutoff = 1e-10 - - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os1 = OpSum() - for e in edges(c) - os1 += 0.5, "S+", src(e), "S-", dst(e) - os1 += 0.5, "S-", src(e), "S+", dst(e) - end - os2 = OpSum() - for e in edges(c) - os2 += "Sz", src(e), "Sz", dst(e) - end - - H1 = ttn(os1, s) - H2 = ttn(os2, s) - Hs = [H1, H2] - - rng = StableRNG(1234) - ψ0 = normalize(random_ttn(rng, s; link_space=10)) - - ψ1 = tdvp(Hs, -0.1im, ψ0; nsweeps=1, cutoff, nsites=1) - - @test norm(ψ1) ≈ 1.0 - - ## Should lose fidelity: - #@test abs(inner(ψ0,ψ1)) < 0.9 - - # Average energy should be conserved: - @test real(sum(H -> inner(ψ1', H, ψ1), Hs)) ≈ sum(H -> inner(ψ0', H, ψ0), Hs) - - # Time evolve backwards: - ψ2 = tdvp(Hs, +0.1im, ψ1; nsweeps=1, cutoff) - - @test norm(ψ2) ≈ 1.0 - - # Should rotate back to original state: - @test abs(inner(ψ0, ψ2)) > 0.99 - end - - @testset "Accuracy Test" begin - tau = 0.1 - ttotal = 1.0 - cutoff = 1e-12 - - tooth_lengths = fill(2, 3) - root_vertex = (3, 2) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) - HM = contract(H) - - Ut = exp(-im * tau * HM) - - state = ttn(ComplexF64, v -> iseven(sum(isodd.(v))) ? "Up" : "Dn", s) - statex = contract(state) - - Sz_tdvp = Float64[] - Sz_exact = Float64[] - - c = (2, 1) - Szc = op("Sz", s[c]) - - Nsteps = Int(ttotal / tau) - for step in 1:Nsteps - statex = noprime(Ut * statex) - statex /= norm(statex) - - state = tdvp( - H, - -im * tau, - state; - cutoff, - normalize=false, - updater_kwargs=(; tol=1e-12, maxiter=500, krylovdim=25), - ) - push!(Sz_tdvp, real(expect("Sz", state; vertices=[c])[c])) - push!(Sz_exact, real(scalar(dag(prime(statex, s[c])) * Szc * statex))) - F = abs(scalar(dag(statex) * contract(state))) - end - - @test norm(Sz_tdvp - Sz_exact) < 1e-5 - end - - # TODO: apply gates in ITensorNetworks - - @testset "TEBD Comparison" begin - cutoff = 1e-12 - maxdim = typemax(Int) - tau = 0.1 - ttotal = 1.0 - - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) - - gates = ITensor[] - for e in edges(c) - s1 = s[src(e)] - s2 = s[dst(e)] - hj = - op("Sz", s1) * op("Sz", s2) + - 1 / 2 * op("S+", s1) * op("S-", s2) + - 1 / 2 * op("S-", s1) * op("S+", s2) - Gj = exp(-1.0im * tau / 2 * hj) - push!(gates, Gj) - end - append!(gates, reverse(gates)) - - state = ttn(v -> iseven(sum(isodd.(v))) ? "Up" : "Dn", s) - phi = copy(state) - c = (2, 1) - - # - # Evolve using TEBD - # - - Nsteps = convert(Int, ceil(abs(ttotal / tau))) - Sz1 = zeros(Nsteps) - En1 = zeros(Nsteps) - Sz2 = zeros(Nsteps) - En2 = zeros(Nsteps) - - for step in 1:Nsteps - state = apply(gates, state; cutoff, maxdim) - - nsites = (step <= 3 ? 2 : 1) - phi = tdvp( - H, - -tau * im, - phi; - nsweeps=1, - cutoff, - nsites, - normalize=true, - updater_kwargs=(; krylovdim=15), - ) - - Sz1[step] = real(expect("Sz", state; vertices=[c])[c]) - Sz2[step] = real(expect("Sz", phi; vertices=[c])[c]) - En1[step] = real(inner(state', H, state)) - En2[step] = real(inner(phi', H, phi)) - end - - # - # Evolve using TDVP - # - - phi = ttn(v -> iseven(sum(isodd.(v))) ? "Up" : "Dn", s) - obs = observer( - "Sz" => (; state) -> expect("Sz", state; vertices=[c])[c], - "En" => (; state) -> real(inner(state', H, state)), - ) - phi = tdvp( - H, - -im * ttotal, - phi; - time_step=-im * tau, - cutoff, - normalize=false, - (sweep_observer!)=obs, - root_vertex=(3, 2), - ) - - @test norm(Sz1 - Sz2) < 5e-3 - @test norm(En1 - En2) < 5e-3 - @test abs.(last(Sz1) - last(obs.Sz)) .< 5e-3 - @test abs.(last(Sz2) - last(obs.Sz)) .< 5e-3 - end - - @testset "Imaginary Time Evolution" for reverse_step in [true, false] - cutoff = 1e-12 - tau = 1.0 - ttotal = 50.0 - - tooth_lengths = fill(2, 3) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - os = ModelHamiltonians.heisenberg(c) - H = ttn(os, s) - - rng = StableRNG(1234) - state = normalize(random_ttn(rng, s; link_space=2)) - - trange = 0.0:tau:ttotal - for (step, t) in enumerate(trange) - nsites = (step <= 10 ? 2 : 1) - state = tdvp( - H, - -tau, - state; - cutoff, - nsites, - reverse_step, - normalize=true, - updater_kwargs=(; krylovdim=15), - ) - end - - @test inner(state', H, state) < -2.47 - end -end -end diff --git a/test/_test_ttn_tdvp_time_dependent.jl b/test/_test_ttn_tdvp_time_dependent.jl deleted file mode 100644 index 4101bc83..00000000 --- a/test/_test_ttn_tdvp_time_dependent.jl +++ /dev/null @@ -1,236 +0,0 @@ -@eval module $(gensym()) -using ITensorNetworks: ITensorNetworks, TimeDependentSum, ttn, mpo, mps, siteinds, tdvp -using ITensorNetworks.ModelHamiltonians: ModelHamiltonians -using ITensors: contract -using KrylovKit: exponentiate -using LinearAlgebra: norm -using NamedGraphs: AbstractNamedEdge -using NamedGraphs.NamedGraphGenerators: named_comb_tree -using OrdinaryDiffEqTsit5: Tsit5 -using Test: @test, @test_broken, @testset - -include( - joinpath( - @__DIR__, "ITensorNetworksTestSolversUtils", "ITensorNetworksTestSolversUtils.jl" - ), -) - -using .ITensorNetworksTestSolversUtils: - ITensorNetworksTestSolversUtils, krylov_solver, ode_solver - -# Functions need to be defined in global scope (outside -# of the @testset macro) - -ω₁ = 0.1 -ω₂ = 0.2 - -ode_alg = Tsit5() -ode_kwargs = (; reltol=1e-8, abstol=1e-8) - -ω⃗ = [ω₁, ω₂] -f⃗ = [t -> cos(ω * t) for ω in ω⃗] -ode_updater_kwargs = (; f=[f⃗], solver_alg=ode_alg, ode_kwargs) - -function ode_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, - ode_kwargs, - solver_alg, - f, -) - region = first(sweep_plan[which_region_update]) - (; time_step, t) = internal_kwargs - t = isa(region, AbstractNamedEdge) ? t : t + time_step - - H⃗₀ = projected_operator![] - result, info = ode_solver( - -im * TimeDependentSum(f, H⃗₀), - time_step, - init; - current_time=t, - solver_alg, - ode_kwargs..., - ) - return result, (; info) -end - -function tdvp_ode_solver(H⃗₀, ψ₀; time_step, kwargs...) - psi_t, info = ode_solver( - -im * TimeDependentSum(f⃗, H⃗₀), time_step, ψ₀; solver_alg=ode_alg, ode_kwargs... - ) - return psi_t, (; info) -end - -krylov_kwargs = (; tol=1e-8, krylovdim=15, eager=true) -krylov_updater_kwargs = (; f=[f⃗], krylov_kwargs) - -function ITensorNetworksTestSolversUtils.krylov_solver( - H⃗₀, ψ₀; time_step, ishermitian=false, issymmetric=false, kwargs... -) - psi_t, info = krylov_solver( - -im * TimeDependentSum(f⃗, H⃗₀), - time_step, - ψ₀; - krylov_kwargs..., - ishermitian, - issymmetric, - ) - return psi_t, (; info) -end - -function krylov_updater( - init; - state!, - projected_operator!, - outputlevel, - which_sweep, - sweep_plan, - which_region_update, - internal_kwargs, - ishermitian=false, - issymmetric=false, - f, - krylov_kwargs, -) - (; time_step, t) = internal_kwargs - H⃗₀ = projected_operator![] - region = first(sweep_plan[which_region_update]) - t = isa(region, AbstractNamedEdge) ? t : t + time_step - - result, info = krylov_solver( - -im * TimeDependentSum(f, H⃗₀), - time_step, - init; - current_time=t, - krylov_kwargs..., - ishermitian, - issymmetric, - ) - return result, (; info) -end - -@testset "MPS: Time dependent Hamiltonian" begin - n = 4 - J₁ = 1.0 - J₂ = 0.1 - - time_step = 0.1 - time_total = 1.0 - - nsites = 2 - maxdim = 100 - cutoff = 1e-8 - - s = siteinds("S=1/2", n) - ℋ₁₀ = ModelHamiltonians.heisenberg(n; J1=J₁, J2=0.0) - ℋ₂₀ = ModelHamiltonians.heisenberg(n; J1=0.0, J2=J₂) - ℋ⃗₀ = [ℋ₁₀, ℋ₂₀] - H⃗₀ = [mpo(ℋ₀, s) for ℋ₀ in ℋ⃗₀] - - ψ₀ = complex(mps(j -> isodd(j) ? "↑" : "↓", s)) - - ψₜ_ode = tdvp( - H⃗₀, - time_total, - ψ₀; - time_step, - maxdim, - cutoff, - nsites, - updater=ode_updater, - updater_kwargs=ode_updater_kwargs, - ) - - ψₜ_krylov = tdvp( - H⃗₀, - time_total, - ψ₀; - time_step, - cutoff, - nsites, - updater=krylov_updater, - updater_kwargs=krylov_updater_kwargs, - ) - - ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total) - - @test norm(ψ₀) ≈ 1 - @test norm(ψₜ_ode) ≈ 1 - @test norm(ψₜ_krylov) ≈ 1 - @test norm(ψₜ_full) ≈ 1 - - ode_err = norm(contract(ψₜ_ode) - ψₜ_full) - krylov_err = norm(contract(ψₜ_krylov) - ψₜ_full) - #ToDo: Investigate why Krylov gives better result than ODE solver - @test_broken krylov_err > ode_err - @test ode_err < 1e-2 - @test krylov_err < 1e-2 -end - -@testset "TTN: Time dependent Hamiltonian" begin - tooth_lengths = fill(2, 3) - root_vertex = (3, 2) - c = named_comb_tree(tooth_lengths) - s = siteinds("S=1/2", c) - - J₁ = 1.0 - J₂ = 0.1 - - time_step = 0.1 - time_total = 1.0 - - nsites = 2 - maxdim = 100 - cutoff = 1e-8 - - s = siteinds("S=1/2", c) - ℋ₁₀ = ModelHamiltonians.heisenberg(c; J1=J₁, J2=0.0) - ℋ₂₀ = ModelHamiltonians.heisenberg(c; J1=0.0, J2=J₂) - ℋ⃗₀ = [ℋ₁₀, ℋ₂₀] - H⃗₀ = [ttn(ℋ₀, s) for ℋ₀ in ℋ⃗₀] - - ψ₀ = ttn(ComplexF64, v -> iseven(sum(isodd.(v))) ? "↑" : "↓", s) - - ψₜ_ode = tdvp( - H⃗₀, - time_total, - ψ₀; - time_step, - maxdim, - cutoff, - nsites, - updater=ode_updater, - updater_kwargs=ode_updater_kwargs, - ) - - ψₜ_krylov = tdvp( - H⃗₀, - time_total, - ψ₀; - time_step, - cutoff, - nsites, - updater=krylov_updater, - updater_kwargs=krylov_updater_kwargs, - ) - ψₜ_full, _ = tdvp_ode_solver(contract.(H⃗₀), contract(ψ₀); time_step=time_total) - - @test norm(ψ₀) ≈ 1 - @test norm(ψₜ_ode) ≈ 1 - @test norm(ψₜ_krylov) ≈ 1 - @test norm(ψₜ_full) ≈ 1 - - ode_err = norm(contract(ψₜ_ode) - ψₜ_full) - krylov_err = norm(contract(ψₜ_krylov) - ψₜ_full) - #ToDo: Investigate why Krylov gives better result than ODE solver - @test_broken krylov_err > ode_err - @test ode_err < 1e-2 - @test krylov_err < 1e-2 -end -end From 60235bc0328787be229d75f1a961bc54da0b523f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 16 Oct 2025 16:43:31 -0400 Subject: [PATCH 52/55] Method `subspace_expand!(::Backend"densitymatrix")` now defines kwarg defaults for the function it calls --- src/solvers/eigsolve.jl | 14 ++++++++++++-- src/solvers/subspace/densitymatrix.jl | 16 +++++++++++----- src/solvers/subspace/subspace.jl | 5 ++--- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index 996979a4..c9ab5a22 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -59,11 +59,21 @@ function default_sweep_callback( end end -function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_kwargs...) +function eigsolve( + operator, init_state; nsweeps, nsites=1, outputlevel=0, factorize_kwargs, sweep_kwargs... +) init_prob = EigsolveProblem(; state=align_indices(init_state), operator=ProjTTN(align_indices(operator)) ) - sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...) + sweep_iter = SweepIterator( + init_prob, + nsweeps; + nsites, + outputlevel, + factorize_kwargs, + subspace_expand!_kwargs=(; eigen_kwargs=factorize_kwargs), + sweep_kwargs..., + ) prob = problem(sweep_solve!(sweep_iter)) return eigenvalue(prob), state(prob) end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 94e047ab..4eaf5753 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -2,7 +2,13 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf @define_default_kwargs function subspace_expand!( - ::Backend"densitymatrix", region_iter, local_state; north_pass=1 + ::Backend"densitymatrix", + region_iter, + local_state; + expansion_factor=1.5, + maxexpand=typemax(Int), + north_pass=1, + eigen_kwargs=(;), ) prob = problem(region_iter) @@ -24,7 +30,7 @@ using Printf: @printf basis_size = prod(dim.(uniqueinds(A, C))) expanded_maxdim = compute_expansion( - dim(a), basis_size; region_kwargs(compute_expansion, region_iter)... + dim(a), basis_size; expansion_factor, maxexpand, eigen_kwargs.maxdim ) expanded_maxdim <= 0 && return region_iter, local_state @@ -38,14 +44,14 @@ using Printf: @printf sqrt_rho *= H[prev_vertex] conj_proj_A(T) = (T - prime(A) * (dag(prime(A)) * T)) - for pass in 1:north_pass + for _ in 1:north_pass sqrt_rho = conj_proj_A(sqrt_rho) end rho = sqrt_rho * dag(noprime(sqrt_rho)) - D, U = eigen(rho; region_kwargs(eigen, region_iter)..., ishermitian=true) + D, U = eigen(rho; eigen_kwargs..., ishermitian=true) Uproj(T) = (T - prime(A, a) * (dag(prime(A, a)) * T)) - for pass in 1:north_pass + for _ in 1:north_pass U = Uproj(U) end if norm(dag(U) * A) > 1E-10 diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index c6177eb3..ebaea7b3 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -24,9 +24,8 @@ function subspace_expand!(backend, region_iterator, local_state; kwargs...) ) end -function compute_expansion( - current_dim, basis_size; expansion_factor=1.5, maxexpand=typemax(Int), maxdim=typemax(Int) -) +# Have these defaults set per backend in `subspace_expand!` +function compute_expansion(current_dim, basis_size; expansion_factor, maxexpand, maxdim) # Note: expand_maxdim will be *added* to current bond dimension # Obtain expand_maxdim from expansion_factor expand_maxdim = ceil(Int, expansion_factor * current_dim) From da3ad278937d0491bf80efb9a26a4a73677b5125 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 16 Oct 2025 16:49:15 -0400 Subject: [PATCH 53/55] Solvers code now no longer relies on `default_kwargs` system To be reintroduced in a later date. --- src/solvers/applyexp.jl | 2 +- src/solvers/eigsolve.jl | 2 +- src/solvers/fitting.jl | 2 +- src/solvers/iterators.jl | 8 +++----- src/solvers/subspace/densitymatrix.jl | 2 +- src/solvers/subspace/subspace.jl | 6 ++---- 6 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/solvers/applyexp.jl b/src/solvers/applyexp.jl index 9c080cfb..3b91ff8e 100644 --- a/src/solvers/applyexp.jl +++ b/src/solvers/applyexp.jl @@ -20,7 +20,7 @@ function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...) return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...) end -@define_default_kwargs function update!( +function update!( region_iter::RegionIterator{<:ApplyExpProblem}, local_state; nsites, diff --git a/src/solvers/eigsolve.jl b/src/solvers/eigsolve.jl index c9ab5a22..907e71f5 100644 --- a/src/solvers/eigsolve.jl +++ b/src/solvers/eigsolve.jl @@ -20,7 +20,7 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing) return E end -@define_default_kwargs function update!( +function update!( region_iter::RegionIterator{<:EigsolveProblem}, local_state; outputlevel=0, diff --git a/src/solvers/fitting.jl b/src/solvers/fitting.jl index d0dfddb8..c03852e6 100644 --- a/src/solvers/fitting.jl +++ b/src/solvers/fitting.jl @@ -44,7 +44,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem}) return region_iter, local_tensor end -@define_default_kwargs function update!( +function update!( region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel=0 ) F = problem(region_iter) diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index b7347bb9..16497f0e 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -100,11 +100,9 @@ function increment!(region_iter::RegionIterator) end function compute!(iter::RegionIterator) - _, local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...) - _, local_state = @with_defaults update!( - iter, local_state; region_kwargs(update!, iter)... - ) - @with_defaults insert!(iter, local_state; region_kwargs(insert!, iter)...) + _, local_state = extract!(iter; region_kwargs(extract!, iter)...) + _, local_state = update!(iter, local_state; region_kwargs(update!, iter)...) + insert!(iter, local_state; region_kwargs(insert!, iter)...) return iter end diff --git a/src/solvers/subspace/densitymatrix.jl b/src/solvers/subspace/densitymatrix.jl index 4eaf5753..ae2ff507 100644 --- a/src/solvers/subspace/densitymatrix.jl +++ b/src/solvers/subspace/densitymatrix.jl @@ -1,7 +1,7 @@ using NamedGraphs.GraphsExtensions: incident_edges using Printf: @printf -@define_default_kwargs function subspace_expand!( +function subspace_expand!( ::Backend"densitymatrix", region_iter, local_state; diff --git a/src/solvers/subspace/subspace.jl b/src/solvers/subspace/subspace.jl index ebaea7b3..d5388245 100644 --- a/src/solvers/subspace/subspace.jl +++ b/src/solvers/subspace/subspace.jl @@ -1,16 +1,14 @@ using NDTensors: NDTensors using NDTensors.BackendSelection: Backend, @Backend_str -@define_default_kwargs function subspace_expand!( - region_iter, local_state; subspace_algorithm="nothing" -) +function subspace_expand!(region_iter, local_state; subspace_algorithm="nothing") backend = Backend(subspace_algorithm) if backend isa Backend"nothing" return region_iter, local_state end - _, local_state = @with_defaults subspace_expand!( + _, local_state = subspace_expand!( backend, region_iter, local_state; region_kwargs(subspace_expand!, region_iter)... ) From 8725370d9882480074b9cdf2b6890383ce4b4f9d Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 16 Oct 2025 16:50:35 -0400 Subject: [PATCH 54/55] Remove `default_kwargs` related to source files --- src/solvers/default_kwargs.jl | 115 ---------------------------- test/solvers/test_default_kwargs.jl | 47 ------------ 2 files changed, 162 deletions(-) delete mode 100644 src/solvers/default_kwargs.jl delete mode 100644 test/solvers/test_default_kwargs.jl diff --git a/src/solvers/default_kwargs.jl b/src/solvers/default_kwargs.jl deleted file mode 100644 index c58925ad..00000000 --- a/src/solvers/default_kwargs.jl +++ /dev/null @@ -1,115 +0,0 @@ -using MacroTools: @capture, splitdef, combinedef, isdef - -""" - default_kwargs(f::Function, args...; kwargs...) - -Returns a set of default keyword arguments, as a `NamedTuple`, for the function `f` -depending on an arbitrary number of positional arguments. Any number of these default -keyword arguments can optionally be overwritten by passing the the keyword as a -keyword argument to this function. -""" -function default_kwargs(f::Function, args...; kwargs...) - return default_kwargs(f, map(typeof, args)...; kwargs...) -end -default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...) - -""" - @define_default_kwargs - -Automatically define a `default_kwargs` method for a given function. This macro should -be applied before a function definition: -``` -@define_default_kwargs function f(arg1::T1, arg2::T2, ...; kwargs...) - ... -end -``` -The defined `default_kwargs` method takes the form -``` -default_kwargs(::typeof(f), arg1::T1, arg2::T2, ...; kwargs...) -``` -i.e. the function signature mirrors that of the function signature of `f`. -""" -macro define_default_kwargs(function_def) - return default_kwargs_macro(function_def) -end - -function default_kwargs_macro(function_def) - if !isdef(function_def) - throw( - ArgumentError( - "The @define_default_kwargs macro must be followed by a function definition" - ), - ) - end - - ex = splitdef(function_def) - new_ex = deepcopy(ex) - - prev_kwargs = [] - - # Give very positional argument a name and escape the type. - ex[:args] = map(ex[:args]) do arg - @capture(arg, (name_::T_) | (::T_) | name_) - if isnothing(name) - name = gensym() - end - if isnothing(T) - T = :Any - end - return :($(name)::$(esc(T))) - end - - # Replacing the kwargs values with the output of `default_kwargs` - ex[:kwargs] = map(ex[:kwargs]) do kw - @capture(kw, (key_::T_ = val_) | (key_ = val_) | key_) - if !isnothing(val) - kw.args[2] = - :(default_kwargs($(esc(ex[:name])), $(ex[:args]...); $(prev_kwargs...)).$key) - end - push!(prev_kwargs, key) - return kw - end - - new_ex[:args] = convert(Vector{Any}, ex[:args]) - - new_ex[:name] = :(ITensorNetworks.default_kwargs) - new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name]))))) - - # Escape anything on the right-hand side of a keyword definition. - new_ex[:kwargs] = map(new_ex[:kwargs]) do kw - @capture(kw, (key_ = val_) | key_) - if !isnothing(val) - kw.args[2] = esc(val) - end - return kw - end - - new_ex[:body] = :(return (; $(prev_kwargs...))) - - # Escape the actual function name - ex[:name] = :($(esc(ex[:name]))) - - rv = quote - $(combinedef(ex)) - $(combinedef(new_ex)) - end - - return rv -end - -macro with_defaults(call_expr) - if @capture(call_expr, (func_(args__; kwargs__)) | (func_(args__))) - if isnothing(kwargs) - kwargs = [] - end - rv = quote - $(esc(func))( - $(esc.(args)...); - default_kwargs($(esc(func)), $(esc.(args)...); $(esc.(kwargs)...))..., - ) - end - return rv - else - throw(ArgumentError("unable to parse function call expression, try including brackets in the macro call.")) - end -end diff --git a/test/solvers/test_default_kwargs.jl b/test/solvers/test_default_kwargs.jl deleted file mode 100644 index 010b0edd..00000000 --- a/test/solvers/test_default_kwargs.jl +++ /dev/null @@ -1,47 +0,0 @@ -using Test: @test, @testset -using ITensorNetworks: AbstractProblem, default_kwargs, RegionIterator, problem, region_kwargs, @with_defaults - -module KwargsTestModule - -using ITensorNetworks -using ITensorNetworks: AbstractProblem, @define_default_kwargs - -struct TestProblem <: AbstractProblem end -struct NotOurTestProblem <: AbstractProblem end - -@define_default_kwargs function test_function(::AbstractProblem; bool=false, int=3) - return bool, int -end -@define_default_kwargs function test_function(::TestProblem; bool=true, int=0) - return bool, int -end - -end # KwargsTestModule - -@testset "Default kwargs" begin - import .KwargsTestModule - - our_iter = RegionIterator(KwargsTestModule.TestProblem(), ["region" => (; test_function_kwargs=(; int=1))], 1) - not_our_iter = RegionIterator(KwargsTestModule.NotOurTestProblem(), ["region" => (; test_function_kwargs=(; int=2))], 1) - - kw = region_kwargs(KwargsTestModule.test_function, our_iter) - @test kw == (; int=1) - kw_not = region_kwargs(KwargsTestModule.test_function, not_our_iter) - @test kw_not == (; int=2) - - # Test dispatch - @test default_kwargs(KwargsTestModule.test_function, problem(our_iter)) == (; bool=true, int=0) - - @test default_kwargs(KwargsTestModule.test_function, problem(not_our_iter)) == (; bool=false, int=3) - - @test KwargsTestModule.test_function(problem(our_iter); default_kwargs(KwargsTestModule.test_function, problem(our_iter); kw...)...) == (true, 1) - @test KwargsTestModule.test_function(problem(not_our_iter); default_kwargs(KwargsTestModule.test_function, problem(not_our_iter); kw_not...)...) == (false, 2) - - @test @with_defaults(KwargsTestModule.test_function(problem(our_iter))) == (true, 0) - @test @with_defaults(KwargsTestModule.test_function(problem(our_iter);)) == (true, 0) - @test @with_defaults(KwargsTestModule.test_function(problem(our_iter); bool = false)) == (false, 0) - - let testval = @with_defaults KwargsTestModule.test_function(problem(our_iter); int = 3) - @test testval == (true, 3) - end -end From 8afce8ab8c656610c08c0a3fe5353e51e5e75187 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Sat, 18 Oct 2025 19:06:39 -0400 Subject: [PATCH 55/55] Delete stale include --- src/ITensorNetworks.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index ec63ee28..339e900e 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -45,7 +45,6 @@ include("treetensornetworks/projttns/projttn.jl") include("treetensornetworks/projttns/projttnsum.jl") include("treetensornetworks/projttns/projouterprodttn.jl") -include("solvers/default_kwargs.jl") include("solvers/local_solvers/eigsolve.jl") include("solvers/local_solvers/exponentiate.jl") include("solvers/local_solvers/runge_kutta.jl")