diff --git a/Project.toml b/Project.toml index f31d793..35bb240 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.6.11" +version = "0.6.12" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -16,10 +16,12 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" [extensions] TensorAlgebraGPUArraysCoreExt = "GPUArraysCore" +TensorAlgebraMooncakeExt = "Mooncake" TensorAlgebraTensorOperationsExt = "TensorOperations" [compat] @@ -30,6 +32,7 @@ FunctionImplementations = "0.3.1, 0.4" GPUArraysCore = "0.2.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6" +Mooncake = "0.4.202" StridedViews = "0.4.1" TensorOperations = "5" TupleTools = "1.6" diff --git a/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl b/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl new file mode 100644 index 0000000..adca47e --- /dev/null +++ b/ext/TensorAlgebraMooncakeExt/TensorAlgebraMooncakeExt.jl @@ -0,0 +1,25 @@ +module TensorAlgebraMooncakeExt + +using Mooncake: Mooncake, @zero_derivative, DefaultCtx +using TensorAlgebra: AbstractBlockPermutation, ContractAlgorithm, allocate_output, + biperm, blockedperms, check_input, contract, contract!, contract_labels, + default_contract_algorithm, select_contract_algorithm + +Mooncake.tangent_type(::Type{<:AbstractBlockPermutation}) = Mooncake.NoTangent +Mooncake.tangent_type(::Type{<:ContractAlgorithm}) = Mooncake.NoTangent + +@zero_derivative DefaultCtx Tuple{ + typeof(allocate_output), typeof(contract), Any, Any, Any, Any, Any, +} +@zero_derivative DefaultCtx Tuple{typeof(biperm), Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(blockedperms), typeof(contract), Any, Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(check_input), typeof(contract), Any, Any, Any, Any} +@zero_derivative DefaultCtx Tuple{ + typeof(check_input), typeof(contract!), Any, Any, Any, Any, Any, Any, +} +@zero_derivative DefaultCtx Tuple{typeof(contract_labels), Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(contract_labels), Any, Any, Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(default_contract_algorithm), Any, Any} +@zero_derivative DefaultCtx Tuple{typeof(select_contract_algorithm), Any, Any, Any} + +end diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 1a2c74b..3a95afc 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -28,6 +28,10 @@ abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockL widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple +# Otherwise it will convert to a BlockTuple since the default `Base.deepcopy` implementation +# calls `map`. +Base.deepcopy(bp::AbstractBlockPermutation) = bp + # Block a permutation based on the specified lengths. # blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1)) # TODO: Optimize with StaticNumbers.jl or generated functions, see: diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 76baafa..e4e3a7b 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -16,12 +16,6 @@ length_domain(t) = 0 length_codomain(t) = length(t) - length_domain(t) -function blockedperms( - f::typeof(contract), alg::ContractAlgorithm, dimnames_dest, dimnames1, dimnames2 - ) - return blockedperms(f, dimnames_dest, dimnames1, dimnames2) -end - # codomain <-- domain function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2))) diff --git a/src/contract/contract.jl b/src/contract/contract.jl index c5a0c22..fb1db8e 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -39,7 +39,9 @@ function contract!( a2::AbstractArray, labels2; kwargs..., ) - return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...) + return contractadd!( + a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs... + ) end # contractadd! diff --git a/test/Project.toml b/test/Project.toml index 0a2aff5..242880d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/test/test_mooncakeext.jl b/test/test_mooncakeext.jl new file mode 100644 index 0000000..4547a64 --- /dev/null +++ b/test/test_mooncakeext.jl @@ -0,0 +1,113 @@ +import Mooncake +import Random +using TensorAlgebra: AbstractBlockPermutation, BlockedPermutation, ContractAlgorithm, + DefaultContractAlgorithm, Matricize, allocate_output, biperm, blockedperms, check_input, + contract, contract!, contract_labels, contractadd!, default_contract_algorithm, + permmortar, select_contract_algorithm +using Test: @test, @testset + +@testset "MooncakeExt" begin + elt = Float64 + mode = Mooncake.ReverseMode + rng = Random.default_rng() + is_primitive = false + atol = eps(real(elt))^(3 / 4) + rtol = eps(real(elt))^(3 / 4) + @testset "zero derivatives" begin + @test Mooncake.tangent_type(AbstractBlockPermutation) ≡ Mooncake.NoTangent + @test Mooncake.tangent_type(BlockedPermutation) ≡ Mooncake.NoTangent + @test Mooncake.tangent_type(ContractAlgorithm) ≡ Mooncake.NoTangent + @test Mooncake.tangent_type(DefaultContractAlgorithm) ≡ Mooncake.NoTangent + @test Mooncake.tangent_type(Matricize) ≡ Mooncake.NoTangent + + dest = randn(elt, (2, 2)) + a1 = randn(elt, (2, 2)) + a2 = randn(elt, (2, 2)) + biperm_dest = permmortar(((1,), (2,))) + biperm1 = permmortar(((1,), (2,))) + biperm2 = permmortar(((1,), (2,))) + labels_dest = (:i, :k) + labels1 = (:i, :j) + labels2 = (:j, :k) + + Mooncake.TestUtils.test_rule( + rng, allocate_output, contract, biperm_dest, a1, biperm1, a2, biperm2; + mode, is_primitive, + ) + Mooncake.TestUtils.test_rule(rng, biperm, (1, 2, 3), Val(2); mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, biperm, (1, 2, 3), 2; mode, is_primitive) + Mooncake.TestUtils.test_rule( + rng, blockedperms, contract, labels_dest, labels1, labels2; mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, check_input, contract, a1, biperm1, a2, biperm2; mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, check_input, contract!, dest, biperm_dest, a1, biperm1, a2, biperm2; + mode, is_primitive, + ) + Mooncake.TestUtils.test_rule( + rng, contract_labels, labels1, labels2; mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, contract_labels, a1, labels1, a2, labels2; mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, default_contract_algorithm, a1, a2; mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, select_contract_algorithm, DefaultContractAlgorithm(), a1, a2; + mode, is_primitive, + ) + end + @testset "contract" begin + α = true + β = false + @testset "contractadd! (BlockedPermutation)" begin + dest = randn(elt, (2, 2)) + a1 = randn(elt, (2, 2)) + a2 = randn(elt, (2, 2)) + biperm_dest = permmortar(((1,), (2,))) + biperm1 = permmortar(((1,), (2,))) + biperm2 = permmortar(((1,), (2,))) + Mooncake.TestUtils.test_rule( + rng, contractadd!, dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; + atol, rtol, mode, is_primitive, + ) + end + @testset "contractadd! (labels)" begin + dest = randn(elt, (2, 2)) + a1 = randn(elt, (2, 2)) + a2 = randn(elt, (2, 2)) + labels_dest = (:i, :k) + labels1 = (:i, :j) + labels2 = (:j, :k) + Mooncake.TestUtils.test_rule( + rng, contractadd!, dest, labels_dest, a1, labels1, a2, labels2, α, β; + atol, rtol, mode, is_primitive, + ) + end + @testset "contract! (labels)" begin + dest = randn(elt, (2, 2)) + a1 = randn(elt, (2, 2)) + a2 = randn(elt, (2, 2)) + labels_dest = (:i, :k) + labels1 = (:i, :j) + labels2 = (:j, :k) + Mooncake.TestUtils.test_rule( + rng, contract!, dest, labels_dest, a1, labels1, a2, labels2; + atol, rtol, mode, is_primitive, + ) + end + @testset "contract (labels)" begin + a1 = randn(elt, (2, 2)) + a2 = randn(elt, (2, 2)) + labels1 = (:i, :j) + labels2 = (:j, :k) + Mooncake.TestUtils.test_rule( + rng, contract, a1, labels1, a2, labels2; + atol, rtol, mode, is_primitive, + ) + end + end +end