diff --git a/Janus.Analyzers/Models/UnionTypeAttribute.Model.cs b/Janus.Analyzers/Models/UnionTypeAttribute.Model.cs index 66250d0..cf82d14 100644 --- a/Janus.Analyzers/Models/UnionTypeAttribute.Model.cs +++ b/Janus.Analyzers/Models/UnionTypeAttribute.Model.cs @@ -92,7 +92,7 @@ private void Initialize(ITypeSymbol variant, CancellationToken ct) Type = new( isArray ? VariantTypeKind.Reference - : actualVariant.IsUnmanagedType + : actualVariant.IsUnmanagedOrUnmanagedUnionType ? VariantTypeKind.Unmanaged : VariantTypeKind.Value, IsNullable: true, @@ -106,7 +106,7 @@ private void Initialize(ITypeSymbol variant, CancellationToken ct) var name = variant.ToDisplayString(TypeDisplayFormat); Type = extractedVariant switch { - { IsUnmanagedType: true } => + { IsUnmanagedOrUnmanagedUnionType: true } => new(isArray ? VariantTypeKind.Reference : VariantTypeKind.Unmanaged, @@ -172,3 +172,85 @@ public override Boolean IsNullable set => base.IsNullable = value; } } + +file static class Extensions +{ + extension(ITypeSymbol symbol) + { + public Boolean IsUnmanagedOrUnmanagedUnionType + { + get + { + var result = isUnmanaged(symbol, null); + + return result; + + Boolean isUnmanaged(ITypeSymbol candidate, Dictionary? unmanagedTypes) + { + if (!candidate.IsUnmanagedType) + { + unmanagedTypes?[candidate] = false; + return false; + } + + unmanagedTypes = new Dictionary(SymbolEqualityComparer.Default); + + if (unmanagedTypes.TryGetValue(candidate, out var memoizedValue)) + { + if (memoizedValue is { } memoizedResult) + { + return memoizedResult; + } + + return true; + } + + var typeAttributes = candidate.GetAttributes(); + foreach (var attribute in typeAttributes) + { + if (!attribute.IsUnionTypeAttribute()) + { + continue; + } + + foreach (var variant in attribute.AttributeClass?.TypeArguments ?? []) + { + if (isUnmanaged(variant, unmanagedTypes)) + { + continue; + } + + unmanagedTypes[candidate] = false; + return false; + } + } + + if (candidate is INamedTypeSymbol namedCandidate) + { + foreach (var typeParameter in namedCandidate.TypeParameters) + { + if (typeParameter.HasUnmanagedTypeConstraint) + { + continue; + } + + var typeParameterAttributes = typeParameter.GetAttributes(); + + foreach (var attribute in typeParameterAttributes) + { + if (attribute.IsUnionTypeAttribute()) + { + unmanagedTypes[candidate] = false; + return false; + } + } + } + } + + unmanagedTypes[candidate] = true; + return true; + } + } + } + } +} diff --git a/Janus.Tests.EndToEnd/UnionTypeVariantTest.cs b/Janus.Tests.EndToEnd/UnionTypeVariantTest.cs new file mode 100644 index 0000000..f263b89 --- /dev/null +++ b/Janus.Tests.EndToEnd/UnionTypeVariantTest.cs @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MPL-2.0 + +namespace RhoMicro.CodeAnalysis.Janus.EndToEnd.Tests; + +public partial class UnionTypeVariantTests +{ + [UnionType] + partial struct Text; + + [UnionType] + partial struct Union; + + [Fact] + public void ManagedUnionTypeStructVariantDoesNotThrowTle() + { + Text text = "foo"; + Union union = text; + Assert.Equal(text, union.CastToText); + } +}