Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
7a51966
Add `Problem` as type parameter to `SweepIterator`
jack-dunham Sep 22, 2025
245182b
Format test files and improve comparisons for readabilty on failure
jack-dunham Sep 24, 2025
7af4b25
Redesign iterator interface by introducing AbstractNetworkIterator ab…
jack-dunham Sep 24, 2025
c0ae5d0
Add `EachRegion` adapter that wraps `RegionIterator`, behaving the sa…
jack-dunham Sep 25, 2025
3b9d0af
Add unit tests for the `AbstractNetworkIterator` interface
jack-dunham Sep 30, 2025
4ef4e75
Rename `done` to `laststep` to better reflect the when it evalutes to…
jack-dunham Sep 30, 2025
e112eb4
Rename `previous_region` to `prev_region` to better align with julia …
jack-dunham Sep 30, 2025
da360e0
Rename `PauseAfterIncrement` -> `NoComputeStep` and improve some vari…
jack-dunham Oct 1, 2025
8bfc483
Make `extract` and `subspace_expand` mutating
jack-dunham Oct 3, 2025
1ef8498
Make `update` mutable
jack-dunham Oct 3, 2025
0a6e891
Make `insert` mutable
jack-dunham Oct 3, 2025
0653c47
First implementation of an `options` system.
jack-dunham Oct 3, 2025
d77321e
Simplify options interface to a single function `default_kwargs`.
jack-dunham Oct 6, 2025
aff14c7
Put calls to `extract!` etc in `compute!` function directly
jack-dunham Oct 6, 2025
4b21cc9
Refactor the region plan generating code.
jack-dunham Oct 7, 2025
e71512f
Have `dmrg` take a strict number of arguments
jack-dunham Oct 7, 2025
a4ce308
Purge non-mutating field setter functions.
jack-dunham Oct 7, 2025
a8b2c51
Use `current_kwargs` for getting kwargs from `RegionIterator`
jack-dunham Oct 7, 2025
18a8503
Introduce defaults using `default_kwargs` and be stricter about which…
jack-dunham Oct 7, 2025
0c9022c
Swap order of local_state and region_iter args
jack-dunham Oct 7, 2025
a9be11e
Add some unit tests for the defaults
jack-dunham Oct 7, 2025
4d52088
Rename file options.jl -> test_default_kwargs.jl
jack-dunham Oct 7, 2025
613d533
Fix `euler_sweep` returning kwargs not as `NamedTuple`
jack-dunham Oct 7, 2025
20bf783
The `sweep_solve` callbacks now get called without any keyword argume…
jack-dunham Oct 7, 2025
568c631
Some minor refactoring of the iterators.
jack-dunham Oct 7, 2025
fed9137
The `EachRegion` adapter now flattens the nested Sweep/Region iterato…
jack-dunham Oct 9, 2025
4ce453e
Add tests for `EachRegion` and `eachregion` wrapper functions
jack-dunham Oct 9, 2025
c59a9c5
Rename `laststep` -> `islaststep` in fitting with Julia conventions.
jack-dunham Oct 9, 2025
62195b6
Overhaul `default_kwargs` such that it mirrors the function signature…
jack-dunham Oct 9, 2025
917f2f1
Rename `NoComputeStep` to `IncrementOnly`
jack-dunham Oct 9, 2025
112d55e
Remove @info statement and fix bug with `astypes` not promoting corre…
jack-dunham Oct 10, 2025
0a9f127
Update `default_kwargs` tests.
jack-dunham Oct 10, 2025
e35f325
Remove stray `end` from `adapters.jl`.
jack-dunham Oct 14, 2025
6a8cdb1
Fix typo in docstring of `EachRegion` adapter.
jack-dunham Oct 14, 2025
9760de1
Function `reverse_regions` is now more concise.
jack-dunham Oct 14, 2025
26ece7b
Use explicit imports in `default_kwargs.jl`
jack-dunham Oct 14, 2025
340d805
Fix test imports and broken tests in `test_iterators.jl`.
jack-dunham Oct 14, 2025
f89c379
Merge branch 'network_solvers' of https://github.com/jack-dunham/ITen…
jack-dunham Oct 14, 2025
6a33f29
Rename @default_kwargs -> @define_default_kwargs
jack-dunham Oct 14, 2025
b4bcb93
Remove `astypes` option from `@define_default_kwargs`.
jack-dunham Oct 14, 2025
624f964
Update `default_kwargs` tests.
jack-dunham Oct 14, 2025
bd35f09
Add `sweep_solve` method for `EachRegion` adapter.
jack-dunham Oct 14, 2025
0b5314d
Add `@with_kwargs` macro which automatically splats `default_kwargs` …
jack-dunham Oct 14, 2025
a58ec92
Make use of `@with_kwargs` macro make code more concise.
jack-dunham Oct 14, 2025
b72a08f
The fallback default callback functions now no longer accept `kwargs.…
jack-dunham Oct 15, 2025
c5de5c4
Test fix: tests founds in sub-directories are now actually ran when i…
jack-dunham Oct 15, 2025
2788057
Skip broken tests for now
jack-dunham Oct 15, 2025
33b9e28
Rename `sweep_solve` -> `sweep_solve!` to obey convention
jack-dunham Oct 15, 2025
dedd82e
The `EachRegion` adapter now returns itself from `iterate` instead of…
jack-dunham Oct 15, 2025
d39f09e
The `sweep_solve!` function now always returns the type of the input …
jack-dunham Oct 15, 2025
3f5c97c
Mutating functions now return the first argument before any additiona…
jack-dunham Oct 15, 2025
7ad3138
Remove depreciated `solvers` code and tests from old interface
jack-dunham Oct 16, 2025
60235bc
Method `subspace_expand!(::Backend"densitymatrix")` now defines kwarg…
jack-dunham Oct 16, 2025
da3ad27
Solvers code now no longer relies on `default_kwargs` system
jack-dunham Oct 16, 2025
8725370
Remove `default_kwargs` related to source files
jack-dunham Oct 16, 2025
8afce8a
Delete stale include
mtfishman Oct 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/solvers/abstract_problem.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

abstract type AbstractProblem end

set_truncation_info(P::AbstractProblem, args...; kws...) = P
set_truncation_info!(P::AbstractProblem, args...; kws...) = P
62 changes: 38 additions & 24 deletions src/solvers/adapters.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,46 @@
"""
struct PauseAfterIncrement{S<:AbstractNetworkIterator}

#
# TupleRegionIterator
#
# Adapts outputs to be (region, region_kwargs) tuples
#
# More generic design? maybe just assuming RegionIterator
# or its outputs implement some interface function that
# generates each tuple?
#

mutable struct TupleRegionIterator{RegionIter}
region_iterator::RegionIter
Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the
process. This allows one to manually call a custom `compute!` or insert their own code it in
the loop body in place of `compute!`.
"""
struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator
parent::S
end

region_iterator(T::TupleRegionIterator) = T.region_iterator
islaststep(adapter::IncrementOnly) = islaststep(adapter.parent)
state(adapter::IncrementOnly) = state(adapter.parent)
increment!(adapter::IncrementOnly) = increment!(adapter.parent)
compute!(adapter::IncrementOnly) = adapter

function Base.iterate(T::TupleRegionIterator, which=1)
state = iterate(region_iterator(T), which)
isnothing(state) && return nothing
(current_region, region_kwargs) = current_region_plan(region_iterator(T))
return (current_region, region_kwargs), last(state)
end
IncrementOnly(adapter::IncrementOnly) = adapter

"""
region_tuples(R::RegionIterator)
struct EachRegion{SweepIterator} <: AbstractNetworkIterator

The `region_tuples` adapter converts a RegionIterator into an
iterator which outputs a tuple of the form (current_region, current_region_kwargs)
at each step.
Adapter that flattens each region iterator in the parent sweep iterator into a single
iterator.
"""
region_tuples(R::RegionIterator) = TupleRegionIterator(R)
struct EachRegion{SI<:SweepIterator} <: AbstractNetworkIterator
parent::SI
end

# In keeping with Julia convention.
eachregion(iter::SweepIterator) = EachRegion(iter)

# Essential definitions
function islaststep(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
return islaststep(adapter.parent) && islaststep(region_iter)
end
function increment!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter)
return adapter
end
function compute!(adapter::EachRegion)
region_iter = region_iterator(adapter.parent)
compute!(region_iter)
return adapter
end
100 changes: 51 additions & 49 deletions src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Printf: @printf
using Accessors: @set

@kwdef mutable struct ApplyExpProblem{State} <: AbstractProblem
operator
Expand All @@ -11,66 +10,69 @@ operator(A::ApplyExpProblem) = A.operator
state(A::ApplyExpProblem) = A.state
current_exponent(A::ApplyExpProblem) = A.current_exponent
function current_time(A::ApplyExpProblem)
t = im*A.current_exponent
t = im * A.current_exponent
return iszero(imag(t)) ? real(t) : t
end

set_operator(A::ApplyExpProblem, operator) = (@set A.operator = operator)
set_state(A::ApplyExpProblem, state) = (@set A.state = state)
set_current_exponent(A::ApplyExpProblem, exponent) = (@set A.current_exponent = exponent)

function region_plan(A::ApplyExpProblem; nsites, time_step, sweep_kwargs...)
return applyexp_regions(state(A), time_step; nsites, sweep_kwargs...)
# Rename region_plan
function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...)
# The `exponent_step` kwarg for the `update!` function needs some pre-processing.
return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...)
end

function update(
prob::ApplyExpProblem,
local_state,
region_iterator;
function update!(
region_iter::RegionIterator{<:ApplyExpProblem},
local_state;
nsites,
exponent_step,
solver=runge_kutta_solver,
outputlevel,
kws...,
)
iszero(abs(exponent_step)) && return prob, local_state
prob = problem(region_iter)

if iszero(abs(exponent_step))
return region_iter, local_state
end

local_state, info = solver(
x->optimal_map(operator(prob), x), exponent_step, local_state; kws...
solver_kwargs = region_kwargs(solver, region_iter)

local_state, _ = solver(
x -> optimal_map(operator(prob), x), exponent_step, local_state; solver_kwargs...
)
if nsites==1
curr_reg = current_region(region_iterator)
next_reg = next_region(region_iterator)
if nsites == 1
curr_reg = current_region(region_iter)
next_reg = next_region(region_iter)
if !isnothing(next_reg) && next_reg != curr_reg
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
v1, v2 = src(next_edge), dst(next_edge)
psi = copy(state(prob))
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
shifted_operator = position(operator(prob), psi, NamedEdge(v1=>v2))
R_t, _ = solver(x->optimal_map(shifted_operator, x), -exponent_step, R; kws...)
local_state = psi[v1]*R_t
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
R_t, _ = solver(
x -> optimal_map(shifted_operator, x), -exponent_step, R; solver_kwargs...
)
local_state = psi[v1] * R_t
end
end

prob = set_current_exponent(prob, current_exponent(prob)+exponent_step)
prob.current_exponent += exponent_step

return prob, local_state
return region_iter, local_state
end

function sweep_callback(
problem::ApplyExpProblem;
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description="exponent",
outputlevel,
sweep,
nsweeps,
outputlevel=0,
process_time=identity,
kws...,
)
if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
" Current %s = %s, ", exponent_description, process_time(current_exponent(problem))
" Current %s = %s, ",
exponent_description,
process_time(current_exponent(the_problem))
)
@printf("maxlinkdim=%d", maxlinkdim(state(problem)))
@printf("maxlinkdim=%d", maxlinkdim(state(the_problem)))
println()
flush(stdout)
end
Expand All @@ -79,19 +81,20 @@ end
function applyexp(
init_prob::AbstractProblem,
exponents;
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
outputlevel=0,
nsites=1,
sweep_callback=default_sweep_callback,
order=4,
kws...,
nsites=2,
sweep_kwargs...,
)
exponent_steps = diff([zero(eltype(exponents)); exponents])
sweep_kws = (; outputlevel, extract_kwargs, insert_kwargs, nsites, order, update_kwargs)
kws_array = [(; sweep_kws..., time_step=t) for t in exponent_steps]
sweep_iter = sweep_iterator(init_prob, kws_array)
converged_prob = sweep_solve(sweep_iter; outputlevel, kws...)

kws_array = [
(; order, nsites, sweep_kwargs..., exponent_step) for exponent_step in exponent_steps
]
sweep_iter = SweepIterator(init_prob, kws_array)

converged_prob = problem(sweep_solve!(sweep_callback, sweep_iter))

return state(converged_prob)
end

Expand All @@ -111,11 +114,10 @@ function time_evolve(
time_points,
init_state;
process_time=process_real_times,
sweep_callback=(
a...; k...
)->sweep_callback(a...; exponent_description="time", process_time, k...),
kws...,
sweep_callback=iter ->
default_sweep_callback(iter; exponent_description="time", process_time),
sweep_kwargs...,
)
exponents = [-im*t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, kws...)
exponents = [-im * t for t in time_points]
return applyexp(operator, exponents, init_state; sweep_callback, sweep_kwargs...)
end
72 changes: 35 additions & 37 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Accessors: @set
using Printf: @printf
using ITensors: truncerror

Expand All @@ -14,42 +13,43 @@ state(E::EigsolveProblem) = E.state
operator(E::EigsolveProblem) = E.operator
max_truncerror(E::EigsolveProblem) = E.max_truncerror

set_operator(E::EigsolveProblem, operator) = (@set E.operator = operator)
set_eigenvalue(E::EigsolveProblem, eigenvalue) = (@set E.eigenvalue = eigenvalue)
set_state(E::EigsolveProblem, state) = (@set E.state = state)
set_max_truncerror(E::EigsolveProblem, truncerror) = (@set E.max_truncerror = truncerror)

function set_truncation_info(E::EigsolveProblem; spectrum=nothing)
function set_truncation_info!(E::EigsolveProblem; spectrum=nothing)
if !isnothing(spectrum)
E = set_max_truncerror(E, max(max_truncerror(E), truncerror(spectrum)))
E.max_truncerror = max(max_truncerror(E), truncerror(spectrum))
end
return E
end

function update(
prob::EigsolveProblem,
local_state,
region_iterator;
outputlevel,
function update!(
region_iter::RegionIterator{<:EigsolveProblem},
local_state;
outputlevel=0,
solver=eigsolve_solver,
kws...,
)
eigval, local_state = solver(ψ->optimal_map(operator(prob), ψ), local_state; kws...)
prob = set_eigenvalue(prob, eigval)
prob = problem(region_iter)

eigval, local_state = solver(
ψ -> optimal_map(operator(prob), ψ), local_state; region_kwargs(solver, region_iter)...
)

prob.eigenvalue = eigval

if outputlevel >= 2
@printf(
" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob)
)
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
end
return prob, local_state
return region_iter, local_state
end

function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, kws...)
function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0
)
if outputlevel >= 1
if nsweeps >= 10
@printf("After sweep %02d/%d ", sweep, nsweeps)
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
if length(sweep_iterator) >= 10
@printf("After sweep %02d/%d ", current_sweep, nsweeps)
else
@printf("After sweep %d/%d ", sweep, nsweeps)
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
Expand All @@ -60,24 +60,22 @@ function sweep_callback(problem::EigsolveProblem; outputlevel, sweep, nsweeps, k
end

function eigsolve(
operator,
init_state;
nsweeps,
nsites=1,
outputlevel=0,
extract_kwargs=(;),
update_kwargs=(;),
insert_kwargs=(;),
kws...,
operator, init_state; nsweeps, nsites=1, outputlevel=0, factorize_kwargs, sweep_kwargs...
)
init_prob = EigsolveProblem(;
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
)
sweep_iter = sweep_iterator(
init_prob, nsweeps; nsites, outputlevel, extract_kwargs, update_kwargs, insert_kwargs
sweep_iter = SweepIterator(
init_prob,
nsweeps;
nsites,
outputlevel,
factorize_kwargs,
subspace_expand!_kwargs=(; eigen_kwargs=factorize_kwargs),
sweep_kwargs...,
)
prob = sweep_solve(sweep_iter; outputlevel, kws...)
prob = problem(sweep_solve!(sweep_iter))
return eigenvalue(prob), state(prob)
end

dmrg(args...; kws...) = eigsolve(args...; kws...)
dmrg(operator, init_state; kwargs...) = eigsolve(operator, init_state; kwargs...)
24 changes: 14 additions & 10 deletions src/solvers/extract.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
function extract(problem, region_iterator; sweep, trunc=(;), kws...)
trunc = truncation_parameters(sweep; trunc...)
region = current_region(region_iterator)
psi = orthogonalize(state(problem), region)
function extract!(region_iter::RegionIterator; subspace_algorithm="nothing")
prob = problem(region_iter)
region = current_region(region_iter)

psi = orthogonalize(state(prob), region)
local_state = prod(psi[v] for v in region)
problem = set_state(problem, psi)
problem, local_state = subspace_expand(
problem, local_state, region_iterator; sweep, trunc, kws...
)
shifted_operator = position(operator(problem), state(problem), region)
return set_operator(problem, shifted_operator), local_state

prob.state = psi

_, local_state = subspace_expand!(region_iter, local_state; subspace_algorithm)
shifted_operator = position(operator(prob), state(prob), region)

prob.operator = shifted_operator

return region_iter, local_state
end
Loading
Loading