diff --git a/src/ErrorProne.NET.Core/CompilationExtensions.cs b/src/ErrorProne.NET.Core/CompilationExtensions.cs
index 58279f2..b87df87 100644
--- a/src/ErrorProne.NET.Core/CompilationExtensions.cs
+++ b/src/ErrorProne.NET.Core/CompilationExtensions.cs
@@ -61,6 +61,20 @@ public static class OperationExtensions
}
}
+ public sealed class TaskTypesInfo
+ {
+ public INamedTypeSymbol? TaskSymbol { get; }
+ public INamedTypeSymbol? TaskOfTSymbol { get; }
+ public INamedTypeSymbol? ValueTaskOfTSymbol { get; }
+
+ public TaskTypesInfo(Compilation compilation)
+ {
+ TaskSymbol = compilation.GetTypeByMetadataName(typeof(Task).FullName);
+ TaskOfTSymbol = compilation.GetTypeByMetadataName(typeof(Task<>).FullName);
+ ValueTaskOfTSymbol = compilation.GetTypeByMetadataName(typeof(ValueTask<>).FullName);
+ }
+ }
+
// Copied from internal ICompilationExtensions class from the roslyn codebase
public static class CompilationExtensions
{
@@ -124,30 +138,29 @@ public static (INamedTypeSymbol? taskType, INamedTypeSymbol? taskOfTType, INamed
return (taskType, taskOfTType, valueTaskOfTType);
}
- public static bool IsTaskLike(this ITypeSymbol? returnType, Compilation compilation)
+ public static bool IsTaskLike(this ITypeSymbol? returnType, TaskTypesInfo info)
{
if (returnType == null)
{
return false;
}
- var (taskType, taskOfTType, valueTaskOfTType) = GetTaskTypes(compilation);
- if (taskType == null || taskOfTType == null)
+ if (info.TaskSymbol == null || info.TaskOfTSymbol == null)
{
return false; // ?
}
- if (returnType.Equals(taskType, SymbolEqualityComparer.Default))
+ if (returnType.Equals(info.TaskSymbol, SymbolEqualityComparer.Default))
{
return true;
}
- if (returnType.OriginalDefinition.Equals(taskOfTType, SymbolEqualityComparer.Default))
+ if (returnType.OriginalDefinition.Equals(info.TaskOfTSymbol, SymbolEqualityComparer.Default))
{
return true;
}
- if (returnType.OriginalDefinition.Equals(valueTaskOfTType, SymbolEqualityComparer.Default))
+ if (returnType.OriginalDefinition.Equals(info.ValueTaskOfTSymbol, SymbolEqualityComparer.Default))
{
return true;
}
diff --git a/src/ErrorProne.NET.Core/DiagnosticAnalyzerBase.cs b/src/ErrorProne.NET.Core/DiagnosticAnalyzerBase.cs
index 5c6b6b4..dcc6869 100644
--- a/src/ErrorProne.NET.Core/DiagnosticAnalyzerBase.cs
+++ b/src/ErrorProne.NET.Core/DiagnosticAnalyzerBase.cs
@@ -33,11 +33,10 @@ protected DiagnosticAnalyzerBase(params DiagnosticDescriptor[] diagnostics)
public sealed override void Initialize(AnalysisContext context)
{
context.EnableConcurrentExecution();
-
- if (ReportDiagnosticsOnGeneratedCode)
- {
- context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
- }
+
+ var flags = ReportDiagnosticsOnGeneratedCode ? GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics : GeneratedCodeAnalysisFlags.None;
+
+ context.ConfigureGeneratedCodeAnalysis(flags);
InitializeCore(context);
}
diff --git a/src/ErrorProne.NET.Core/NamedSymbolExtensions.cs b/src/ErrorProne.NET.Core/NamedSymbolExtensions.cs
index c297eac..8be10ff 100644
--- a/src/ErrorProne.NET.Core/NamedSymbolExtensions.cs
+++ b/src/ErrorProne.NET.Core/NamedSymbolExtensions.cs
@@ -84,37 +84,5 @@ private static string BuildQualifiedAssemblyName(string? nameSpace, string typeN
return $"{symbolType}, {new AssemblyName(assemblySymbol.Identity.GetDisplayName(true))}";
}
-
- public static bool IsDerivedFromInterface(this INamedTypeSymbol namedType, Type type)
- {
- Contract.Requires(namedType != null);
- Contract.Requires(type != null);
-
- return Enumerable.Any(namedType.AllInterfaces, symbol => symbol.IsType(type));
- }
-
- public static bool IsExceptionType(this ISymbol? symbol, SemanticModel model)
- {
- if (!(symbol is INamedTypeSymbol namedSymbol))
- {
- return false;
- }
-
- var exceptionType = model.Compilation.GetTypeByFullName(typeof(Exception).FullName);
-
- return TraverseTypeAndItsBaseTypes(namedSymbol).Any(x => x.Equals(exceptionType, SymbolEqualityComparer.Default));
- }
-
- public static bool IsArgumentExceptionType(this ISymbol? symbol, SemanticModel model)
- {
- if (!(symbol is INamedTypeSymbol namedSymbol))
- {
- return false;
- }
-
- var exceptionType = model.Compilation.GetTypeByFullName(typeof(ArgumentException).FullName);
-
- return TraverseTypeAndItsBaseTypes(namedSymbol).Any(x => x.Equals(exceptionType, SymbolEqualityComparer.Default));
- }
}
}
\ No newline at end of file
diff --git a/src/ErrorProne.NET.Core/SymbolExtensions.cs b/src/ErrorProne.NET.Core/SymbolExtensions.cs
index a7bdb7e..0d4faab 100644
--- a/src/ErrorProne.NET.Core/SymbolExtensions.cs
+++ b/src/ErrorProne.NET.Core/SymbolExtensions.cs
@@ -135,23 +135,6 @@ public static Location GetParameterLocation(this IParameterSymbol parameter)
return parameter.Locations[0];
}
- ///
- /// Returns true if a given is .
- ///
- public static bool IsConfigureAwait(this IMethodSymbol method, Compilation compilation)
- {
- // Naive implementation
- return method.Name == "ConfigureAwait" && method.ReceiverType.IsTaskLike(compilation);
- }
-
- ///
- /// Returns true if a given is .
- ///
- public static bool IsContinueWith(this IMethodSymbol method, Compilation compilation)
- {
- return method.Name == "ContinueWith" && method.ReceiverType.IsTaskLike(compilation) && method.ReturnType.IsTaskLike(compilation);
- }
-
///
/// Returns true if a given has iterator block inside of it.
///
@@ -168,7 +151,7 @@ public static bool IsIteratorBlock(this IMethodSymbol method)
///
/// Returns true if the given is async or return task-like type.
///
- public static bool IsAsyncOrTaskBased(this IMethodSymbol method, Compilation compilation)
+ public static bool IsAsyncOrTaskBased(this IMethodSymbol method, TaskTypesInfo info)
{
// Currently method detects only Task or ValueTask
if (method.IsAsync)
@@ -176,7 +159,7 @@ public static bool IsAsyncOrTaskBased(this IMethodSymbol method, Compilation com
return true;
}
- return method.ReturnType.IsTaskLike(compilation);
+ return method.ReturnType.IsTaskLike(info);
}
///
diff --git a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitConfiguration.cs b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitConfiguration.cs
index 078522b..a1ecfec 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitConfiguration.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitConfiguration.cs
@@ -28,13 +28,4 @@ public enum ConfigureAwait
UseConfigureAwaitFalse,
DoNotUseConfigureAwait,
}
-
- internal static class TempExtensions
- {
- public static bool IsConfigureAwait(this IMethodSymbol method, Compilation compilation)
- {
- // Naive implementation
- return method.Name == "ConfigureAwait" && method.ReceiverType.IsTaskLike(compilation);
- }
- }
}
\ No newline at end of file
diff --git a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitRequiredAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitRequiredAnalyzer.cs
index 95d7bb0..f7cd0fa 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitRequiredAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/ConfigureAwaitRequiredAnalyzer.cs
@@ -1,11 +1,10 @@
-using System.Runtime.CompilerServices;
+using System.Collections.Immutable;
+using System.Runtime.CompilerServices;
+using System.Threading.Tasks;
using ErrorProne.NET.CoreAnalyzers;
using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
-using CompilationExtensions = ErrorProne.NET.Core.CompilationExtensions;
namespace ErrorProne.NET.AsyncAnalyzers
{
@@ -27,38 +26,56 @@ public ConfigureAwaitRequiredAnalyzer()
///
protected override void InitializeCore(AnalysisContext context)
{
- context.RegisterSyntaxNodeAction(AnalyzeAwaitExpression, SyntaxKind.AwaitExpression);
+ context.RegisterCompilationStartAction(context =>
+ {
+ var compilation = context.Compilation;
+
+ var configureAwaitConfig = ConfigureAwaitConfiguration.TryGetConfigureAwait(context.Compilation);
+ if (configureAwaitConfig != ConfigureAwait.UseConfigureAwaitFalse)
+ {
+ return;
+ }
+
+ var taskType = compilation.GetTypeByMetadataName(typeof(Task).FullName);
+ if (taskType is null)
+ {
+ return;
+ }
+
+ var configureAwaitMethods = taskType.GetMembers("ConfigureAwait").OfType().ToImmutableArray();
+ if (configureAwaitMethods.IsEmpty)
+ {
+ return;
+ }
+
+ var yieldAwaitable = compilation.GetTypeByMetadataName(typeof(YieldAwaitable).FullName);
+
+ context.RegisterOperationAction(context => AnalyzeAwaitOperation(context, configureAwaitMethods, yieldAwaitable), OperationKind.Await);
+ });
+
}
- private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
+ private static void AnalyzeAwaitOperation(OperationAnalysisContext context, ImmutableArray configureAwaitMethods, INamedTypeSymbol? yieldAwaitable)
{
- var invocation = (AwaitExpressionSyntax)context.Node;
+ var awaitOperation = (IAwaitOperation)context.Operation;
- var configureAwaitConfig = ConfigureAwaitConfiguration.TryGetConfigureAwait(context.Compilation);
- if (configureAwaitConfig == ConfigureAwait.UseConfigureAwaitFalse)
+ if (awaitOperation.Operation is IInvocationOperation configureAwaitOperation)
{
- var operation = context.SemanticModel.GetOperation(invocation, context.CancellationToken);
- if (operation is IAwaitOperation awaitOperation)
+ if (configureAwaitMethods.Contains(configureAwaitOperation.TargetMethod))
{
- if (awaitOperation.Operation is IInvocationOperation configureAwaitOperation)
- {
- if (configureAwaitOperation.TargetMethod.IsConfigureAwait(context.Compilation))
- {
- return;
- }
-
- if (CompilationExtensions.IsClrType(configureAwaitOperation.Type, context.Compilation, typeof(YieldAwaitable)))
- {
- return;
- }
- }
-
- var location = awaitOperation.Syntax.GetLocation();
+ return;
+ }
- var diagnostic = Diagnostic.Create(Rule, location);
- context.ReportDiagnostic(diagnostic);
+ if (SymbolEqualityComparer.Default.Equals(configureAwaitOperation.Type, yieldAwaitable))
+ {
+ return;
}
}
+
+ var location = awaitOperation.Syntax.GetLocation();
+
+ var diagnostic = Diagnostic.Create(Rule, location);
+ context.ReportDiagnostic(diagnostic);
}
}
}
\ No newline at end of file
diff --git a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/RedundantConfigureAwaitFalseAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/RedundantConfigureAwaitFalseAnalyzer.cs
index 2afa5aa..bc6c31b 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/RedundantConfigureAwaitFalseAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/RedundantConfigureAwaitFalseAnalyzer.cs
@@ -1,4 +1,6 @@
-using System.Diagnostics.ContractsLight;
+using System.Collections.Immutable;
+using System.Diagnostics.ContractsLight;
+using System.Threading.Tasks;
using ErrorProne.NET.CoreAnalyzers;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
@@ -27,45 +29,63 @@ public RedundantConfigureAwaitFalseAnalyzer()
///
protected override void InitializeCore(AnalysisContext context)
{
- context.RegisterSyntaxNodeAction(AnalyzeAwaitExpression, SyntaxKind.AwaitExpression);
+ context.RegisterCompilationStartAction(context =>
+ {
+ var compilation = context.Compilation;
+
+ var configureAwaitConfig = ConfigureAwaitConfiguration.TryGetConfigureAwait(context.Compilation);
+ if (configureAwaitConfig != ConfigureAwait.DoNotUseConfigureAwait)
+ {
+ return;
+ }
+
+ var taskType = compilation.GetTypeByMetadataName(typeof(Task).FullName);
+ if (taskType is null)
+ {
+ return;
+ }
+
+ var configureAwaitMethods = taskType.GetMembers("ConfigureAwait").OfType().ToImmutableArray();
+ if (configureAwaitMethods.IsEmpty)
+ {
+ return;
+ }
+
+ context.RegisterOperationAction(context => AnalyzeAwaitOperation(context, configureAwaitMethods), OperationKind.Await);
+ });
+
}
- private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
+ private static void AnalyzeAwaitOperation(OperationAnalysisContext context, ImmutableArray configureAwaitMethods)
{
- var invocation = (AwaitExpressionSyntax)context.Node;
+ var awaitOperation = (IAwaitOperation)context.Operation;
- var configureAwaitConfig = ConfigureAwaitConfiguration.TryGetConfigureAwait(context.Compilation);
- if (configureAwaitConfig == ConfigureAwait.DoNotUseConfigureAwait)
+ if (awaitOperation.Operation is IInvocationOperation configureAwaitOperation &&
+ configureAwaitMethods.Contains(configureAwaitOperation.TargetMethod))
{
- var operation = context.SemanticModel.GetOperation(invocation, context.CancellationToken);
- if (operation is IAwaitOperation awaitOperation &&
- awaitOperation.Operation is IInvocationOperation configureAwaitOperation &&
- configureAwaitOperation.TargetMethod.IsConfigureAwait(context.Compilation))
+ if (configureAwaitOperation.Arguments.Length != 0 &&
+ configureAwaitOperation.Arguments[0].Value is ILiteralOperation literal &&
+ literal.ConstantValue.Value?.Equals(false) == true)
{
- if (configureAwaitOperation.Arguments.Length != 0 &&
- configureAwaitOperation.Arguments[0].Value is ILiteralOperation literal &&
- literal.ConstantValue.Value?.Equals(false) == true)
- {
- var location = configureAwaitOperation.Syntax.GetLocation();
+ var location = configureAwaitOperation.Syntax.GetLocation();
- // Need to find 'ConfigureAwait' node.
- if (configureAwaitOperation.Syntax is InvocationExpressionSyntax i &&
- i.Expression is
- MemberAccessExpressionSyntax mae)
- {
- // This is a really weird way for getting a location for 'ConfigureAwait(false)' span!
+ // Need to find 'ConfigureAwait' node.
+ if (configureAwaitOperation.Syntax is InvocationExpressionSyntax i &&
+ i.Expression is
+ MemberAccessExpressionSyntax mae)
+ {
+ // This is a really weird way for getting a location for 'ConfigureAwait(false)' span!
- var argsLocation = i.ArgumentList.GetLocation();
- var nameLocation = mae.Name.GetLocation().SourceSpan;
-
- Contract.Assert(argsLocation.SourceTree != null);
- location = Location.Create(argsLocation.SourceTree,
- TextSpan.FromBounds(nameLocation.Start, argsLocation.SourceSpan.End));
- }
+ var argsLocation = i.ArgumentList.GetLocation();
+ var nameLocation = mae.Name.GetLocation().SourceSpan;
- var diagnostic = Diagnostic.Create(Rule, location);
- context.ReportDiagnostic(diagnostic);
+ Contract.Assert(argsLocation.SourceTree != null);
+ location = Location.Create(argsLocation.SourceTree,
+ TextSpan.FromBounds(nameLocation.Start, argsLocation.SourceSpan.End));
}
+
+ var diagnostic = Diagnostic.Create(Rule, location);
+ context.ReportDiagnostic(diagnostic);
}
}
}
diff --git a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/TaskInstanceToStringConversionAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/TaskInstanceToStringConversionAnalyzer.cs
index d236e86..f18b434 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/TaskInstanceToStringConversionAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/AsyncAnalyzers/TaskInstanceToStringConversionAnalyzer.cs
@@ -23,11 +23,11 @@ public TaskInstanceToStringConversionAnalyzer()
{
}
- protected override bool TryCreateDiagnostic(Compilation compilation, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic)
+ protected override bool TryCreateDiagnostic(TaskTypesInfo info, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic)
{
diagnostic = null;
- if (type.IsTaskLike(compilation))
+ if (type.IsTaskLike(info))
{
diagnostic = Diagnostic.Create(Rule, location);
}
diff --git a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/AbstractDefaultToStringImplementationUsageAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/AbstractDefaultToStringImplementationUsageAnalyzer.cs
index a11cd08..222d164 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/AbstractDefaultToStringImplementationUsageAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/AbstractDefaultToStringImplementationUsageAnalyzer.cs
@@ -19,17 +19,21 @@ protected AbstractDefaultToStringImplementationUsageAnalyzer(DiagnosticDescripto
{
}
- protected abstract bool TryCreateDiagnostic(Compilation compilation, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic);
+ protected abstract bool TryCreateDiagnostic(TaskTypesInfo info, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic);
///
protected override void InitializeCore(AnalysisContext context)
{
- context.RegisterOperationAction(AnalyzeConversion, OperationKind.Conversion);
- context.RegisterOperationAction(AnalyzeInterpolation, OperationKind.Interpolation);
- context.RegisterOperationAction(AnalyzeMethodInvocation, OperationKind.Invocation);
+ context.RegisterCompilationStartAction(context =>
+ {
+ var taskTypesInfo = new TaskTypesInfo(context.Compilation);
+ context.RegisterOperationAction(context => AnalyzeConversion(context, taskTypesInfo), OperationKind.Conversion);
+ context.RegisterOperationAction(context => AnalyzeInterpolation(context, taskTypesInfo), OperationKind.Interpolation);
+ context.RegisterOperationAction(context => AnalyzeMethodInvocation(context, taskTypesInfo), OperationKind.Invocation);
+ });
}
- private void AnalyzeMethodInvocation(OperationAnalysisContext context)
+ private void AnalyzeMethodInvocation(OperationAnalysisContext context, TaskTypesInfo info)
{
var methodCall = (IInvocationOperation)context.Operation;
if (methodCall.Instance?.Type is not null &&
@@ -38,7 +42,7 @@ methodCall.TargetMethod.ContainingType.SpecialType is SpecialType.System_Object
.System_ValueType)
{
if (TryCreateDiagnostic(
- context.Compilation,
+ info,
methodCall.Instance.Type,
methodCall.Syntax.GetLocation(),
out var diagnostic))
@@ -48,14 +52,14 @@ methodCall.TargetMethod.ContainingType.SpecialType is SpecialType.System_Object
}
}
- private void AnalyzeInterpolation(OperationAnalysisContext context)
+ private void AnalyzeInterpolation(OperationAnalysisContext context, TaskTypesInfo info)
{
// This method checks for $"foobar: {taskLikeThing}";
if (context.Operation is IInterpolationOperation interpolationOperation)
{
if (interpolationOperation.Expression.Type is not null &&
TryCreateDiagnostic(
- context.Compilation,
+ info,
interpolationOperation.Expression.Type,
interpolationOperation.Expression.Syntax.GetLocation(),
out var diagnostic))
@@ -65,7 +69,7 @@ private void AnalyzeInterpolation(OperationAnalysisContext context)
}
}
- private void AnalyzeConversion(OperationAnalysisContext context)
+ private void AnalyzeConversion(OperationAnalysisContext context, TaskTypesInfo info)
{
// This method checks for "something" + taskLikeThing;
// or string.Format("{0}", taskLikeThing);
@@ -78,7 +82,7 @@ private void AnalyzeConversion(OperationAnalysisContext context)
if ((isToStringConversion(conversion.Parent)) && conversion.Operand.Type is not null)
{
if (TryCreateDiagnostic(
- context.Compilation,
+ info,
conversion.Operand.Type,
context.Operation.Syntax.GetLocation(),
out var diagnostic))
@@ -95,7 +99,7 @@ private void AnalyzeConversion(OperationAnalysisContext context)
conversion.Operand.Type is not null)
{
if (TryCreateDiagnostic(
- context.Compilation,
+ info,
conversion.Operand.Type,
context.Operation.Syntax.GetLocation(),
out var diagnostic))
diff --git a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/DefaultToStringImplementationUsageAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/DefaultToStringImplementationUsageAnalyzer.cs
index 72e3ddf..475762c 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/DefaultToStringImplementationUsageAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/DefaultToStringImplementationUsageAnalyzer.cs
@@ -28,7 +28,7 @@ public DefaultToStringImplementationUsageAnalyzer()
}
///
- protected override bool TryCreateDiagnostic(Compilation compilation, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic)
+ protected override bool TryCreateDiagnostic(TaskTypesInfo info, ITypeSymbol type, Location location, [NotNullWhen(true)]out Diagnostic? diagnostic)
{
diagnostic = null;
diff --git a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/UnobservedResultAnalyzer.cs b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/UnobservedResultAnalyzer.cs
index 9edcb40..e962466 100644
--- a/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/UnobservedResultAnalyzer.cs
+++ b/src/ErrorProne.NET.CoreAnalyzers/CoreAnalyzers/UnobservedResultAnalyzer.cs
@@ -17,6 +17,24 @@ namespace ErrorProne.NET.CoreAnalyzers
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class UnobservedResultAnalyzer : DiagnosticAnalyzer
{
+ private sealed class UnobservedResultAnalyzerInfo
+ {
+ public ImmutableArray ConfigureAwaitMethods { get; }
+ public ImmutableHashSet ContinueWithMethods { get; }
+ public INamedTypeSymbol? ExceptionSymbol { get; }
+ public TaskTypesInfo TaskTypesInfo { get; }
+
+ public UnobservedResultAnalyzerInfo(Compilation compilation)
+ {
+ TaskTypesInfo = new TaskTypesInfo(compilation);
+ ExceptionSymbol = compilation.GetTypeByMetadataName(typeof(Exception).FullName);
+ ConfigureAwaitMethods = TaskTypesInfo.TaskSymbol?.GetMembers("ConfigureAwait").OfType().ToImmutableArray() ?? ImmutableArray.Empty;
+#pragma warning disable RS1024 // Symbols should be compared for equality
+ ContinueWithMethods = TaskTypesInfo.TaskSymbol?.GetMembers("ContinueWith").OfType().ToImmutableHashSet() ?? ImmutableHashSet.Empty;
+#pragma warning restore RS1024 // Symbols should be compared for equality
+ }
+ }
+
private static DiagnosticDescriptor Rule => DiagnosticDescriptors.EPC13;
public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(Rule);
@@ -33,28 +51,31 @@ public override void Initialize(AnalysisContext context)
context.EnableConcurrentExecution();
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
- context.RegisterSyntaxNodeAction(AnalyzeAwaitExpression, SyntaxKind.AwaitExpression);
- context.RegisterSyntaxNodeAction(AnalyzeMethodInvocation, SyntaxKind.InvocationExpression);
+ context.RegisterCompilationStartAction(context =>
+ {
+ var compilation = context.Compilation;
+
+ var info = new UnobservedResultAnalyzerInfo(compilation);
+
+ context.RegisterOperationAction(context => AnalyzeAwaitOperation(context, info), OperationKind.Await);
+ context.RegisterOperationAction(context => AnalyzeMethodInvocation(context, info), OperationKind.Invocation);
+ });
}
- private void AnalyzeMethodInvocation(SyntaxNodeAnalysisContext context)
+ private static void AnalyzeMethodInvocation(OperationAnalysisContext context, UnobservedResultAnalyzerInfo info)
{
- var invocation = (InvocationExpressionSyntax)context.Node;
-
- if (invocation.Parent is ExpressionStatementSyntax ex &&
- context.SemanticModel.GetSymbolInfo(ex.Expression).Symbol is IMethodSymbol ms &&
- TypeMustBeObserved(ms.ReturnType, ms, context.Compilation))
+ var invocationOperation = (IInvocationOperation)context.Operation;
+ var invocation = (InvocationExpressionSyntax)invocationOperation.Syntax;
+ if (invocationOperation.Parent is IExpressionStatementOperation &&
+ TypeMustBeObserved(invocationOperation.TargetMethod.ReturnType, invocationOperation.TargetMethod, info))
{
- var operation = context.SemanticModel.GetOperation(invocation);
- var invocationOperation = operation as IInvocationOperation;
- if (invocationOperation != null &&
- ResultObservedByExtensionMethod(invocationOperation, context.SemanticModel))
+ if (ResultObservedByExtensionMethod(invocationOperation, info))
{
// Result is observed!
return;
}
- if (invocationOperation != null && IsException(invocationOperation.Type, context.Compilation) &&
+ if (IsException(invocationOperation.Type, info) &&
invocationOperation.TargetMethod.ConstructedFrom.ReturnType is ITypeParameterSymbol)
{
// The inferred type is System.Exception (or one of it's derived types),
@@ -63,23 +84,22 @@ private void AnalyzeMethodInvocation(SyntaxNodeAnalysisContext context)
return;
}
- var diagnostic = Diagnostic.Create(Rule, invocation.GetNodeLocationForDiagnostic(), ms.ReturnType.Name);
+ var diagnostic = Diagnostic.Create(Rule, invocation.GetNodeLocationForDiagnostic(), invocationOperation.TargetMethod.ReturnType.Name);
context.ReportDiagnostic(diagnostic);
}
}
- private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
+ private static void AnalyzeAwaitOperation(OperationAnalysisContext context, UnobservedResultAnalyzerInfo info)
{
- var awaitExpression = (AwaitExpressionSyntax)context.Node;
+ var awaitOperation = (IAwaitOperation)context.Operation;
// await can be used on a task value, so the awaited expression may be anything.
- if (awaitExpression.Parent is ExpressionStatementSyntax)
+ if (awaitOperation.Parent is IExpressionStatementOperation)
{
- var operation = context.SemanticModel.GetOperation(awaitExpression);
- if (operation is IAwaitOperation awaitOperation && operation.Type != null && TypeMustBeObserved(operation.Type, null, context.Compilation))
+ if (awaitOperation.Type != null && TypeMustBeObserved(awaitOperation.Type, null, info))
{
if (awaitOperation.Operation is IInvocationOperation invocation &&
- ResultObservedByExtensionMethod(invocation, context.SemanticModel))
+ ResultObservedByExtensionMethod(invocation, info))
{
// Result is observed!
return;
@@ -87,21 +107,21 @@ private void AnalyzeAwaitExpression(SyntaxNodeAnalysisContext context)
// Making an exception for 'Task' case.
// For instance, the following code is totally fine: await Task.WhenAll(t1, t2);
- if (operation.Type.IsTaskLike(context.Compilation))
+ if (awaitOperation.Type.IsTaskLike(info.TaskTypesInfo))
{
return;
}
// Need to extract a real method if this one is 'ConfigureAwait'
- var location = GetLocationForDiagnostic(awaitExpression);
+ var location = GetLocationForDiagnostic((AwaitExpressionSyntax)awaitOperation.Syntax);
- var diagnostic = Diagnostic.Create(Rule, location, operation.Type.Name);
+ var diagnostic = Diagnostic.Create(Rule, location, awaitOperation.Type.Name);
ReportDiagnostic(context, diagnostic);
}
}
}
- private static bool ResultObservedByExtensionMethod(IInvocationOperation operation, SemanticModel semanticModel)
+ private static bool ResultObservedByExtensionMethod(IInvocationOperation operation, UnobservedResultAnalyzerInfo info)
{
// In some cases, the following pattern is used:
// Foo().Handle();
@@ -111,7 +131,7 @@ private static bool ResultObservedByExtensionMethod(IInvocationOperation operati
var methodSymbol = operation.TargetMethod;
// Exception for this rule is 'ConfigureAwait()'
- if (operation.TargetMethod.IsConfigureAwait(semanticModel.Compilation))
+ if (info.ConfigureAwaitMethods.Contains(methodSymbol))
{
return false;
}
@@ -133,7 +153,7 @@ private static Location GetLocationForDiagnostic(AwaitExpressionSyntax awaitExpr
return awaitExpression.GetLocation();
}
- private static void ReportDiagnostic(SyntaxNodeAnalysisContext context, Diagnostic diagnostic)
+ private static void ReportDiagnostic(OperationAnalysisContext context, Diagnostic diagnostic)
{
#if DEBUG
Console.WriteLine($"ERROR: {diagnostic}");
@@ -141,20 +161,20 @@ private static void ReportDiagnostic(SyntaxNodeAnalysisContext context, Diagnost
context.ReportDiagnostic(diagnostic);
}
- private static bool TypeMustBeObserved(ITypeSymbol type, IMethodSymbol? method, Compilation compilation)
+ private static bool TypeMustBeObserved(ITypeSymbol type, IMethodSymbol? method, UnobservedResultAnalyzerInfo info)
{
- if (method?.IsContinueWith(compilation) == true)
+ if (method is not null && info.ContinueWithMethods.Contains(method))
{
// Task.ContinueWith is a bit special.
return false;
}
- return type.EnumerateBaseTypesAndSelf().Any(t => IsObservableType(t, method, compilation));
+ return type.EnumerateBaseTypesAndSelf().Any(t => IsObservableType(t, method, info));
}
- private static bool IsObservableType(ITypeSymbol type, IMethodSymbol? method, Compilation compilation)
+ private static bool IsObservableType(ITypeSymbol type, IMethodSymbol? method, UnobservedResultAnalyzerInfo info)
{
- if (type.IsClrType(compilation, typeof(Exception)))
+ if (type.Equals(info.ExceptionSymbol, SymbolEqualityComparer.Default))
{
// 'ThrowException' method that throws but still returns an exception is quite common.
var methodName = method?.Name;
@@ -171,7 +191,7 @@ private static bool IsObservableType(ITypeSymbol type, IMethodSymbol? method, Co
return true;
}
- if (type.IsClrType(compilation, typeof(Task)))
+ if (type.Equals(info.TaskTypesInfo.TaskSymbol, SymbolEqualityComparer.Default))
{
// Tasks should be observed
return true;
@@ -186,9 +206,9 @@ private static bool IsObservableType(ITypeSymbol type, IMethodSymbol? method, Co
return false;
}
- private static bool IsException(ITypeSymbol? type, Compilation compilation)
+ private static bool IsException(ITypeSymbol? type, UnobservedResultAnalyzerInfo info)
{
- return type.EnumerateBaseTypesAndSelf().Any(t => t.IsClrType(compilation, typeof(Exception)));
+ return type.EnumerateBaseTypesAndSelf().Any(t => t.Equals(info.ExceptionSymbol, SymbolEqualityComparer.Default));
}
}
diff --git a/src/ErrorProne.NET.CoreAnalyzers/ExceptionsAnalyzers/PreconditionsBlock.cs b/src/ErrorProne.NET.CoreAnalyzers/ExceptionsAnalyzers/PreconditionsBlock.cs
deleted file mode 100644
index 26b5810..0000000
--- a/src/ErrorProne.NET.CoreAnalyzers/ExceptionsAnalyzers/PreconditionsBlock.cs
+++ /dev/null
@@ -1,94 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Diagnostics.ContractsLight;
-using System.Linq;
-using ErrorProne.NET.Extensions;
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-
-namespace ErrorProne.NET.ExceptionsAnalyzers
-{
- internal sealed class IfThrowPrecondition
- {
- public IfThrowPrecondition(StatementSyntax ifThrowStatement, ThrowStatementSyntax throwStatement)
- {
- Contract.Requires(ifThrowStatement != null);
- Contract.Requires(throwStatement != null);
-
- IfThrowStatement = ifThrowStatement;
- ThrowStatement = throwStatement;
- }
-
- public StatementSyntax IfThrowStatement { get; }
- public ThrowStatementSyntax ThrowStatement { get; }
- }
-
- ///
- /// Class that holds all checks that could be considered as a method preconditions.
- ///
- internal sealed class PreconditionsBlock
- {
- public PreconditionsBlock(List preconditions)
- {
- Preconditions = preconditions.ToImmutableList();
- }
-
- public ImmutableList Preconditions { get; }
-
- public static PreconditionsBlock GetPreconditions(MethodDeclarationSyntax method, SemanticModel semanticModel)
- {
- Contract.Requires(method != null);
-
- var preconditions = new List();
-
- // Precondition block ends when something exception precondition check is met.
- foreach (var statement in method.Body?.Statements ?? Enumerable.Empty())
- {
- // Currently, If-throw precondition means that
- // if statement has only one statement in the if block
- // and this statement is a throw of type ArgumentException
- var ifThrowStatement = statement as IfStatementSyntax;
- if (ifThrowStatement == null)
- {
- break;
- }
-
- var block = ifThrowStatement.Statement as BlockSyntax;
- if (block != null && block.Statements.Count != 1)
- {
- break;
- }
-
- var throwStatementCandidate = block != null ? block.Statements[0] : ifThrowStatement.Statement;
-
- // The only valid case (when the processing should keep going)
- // is when the if block has one statement and that statement is a throw of ArgumentException
- if (IsThrowArgumentExceptionStatement(throwStatementCandidate, semanticModel))
- {
- preconditions.Add(new IfThrowPrecondition(statement, (ThrowStatementSyntax) throwStatementCandidate));
- }
- else
- {
- break;
- }
- }
-
- return new PreconditionsBlock(preconditions);
- }
-
- private static bool IsThrowArgumentExceptionStatement(StatementSyntax statement, SemanticModel semanticModel)
- {
- var throwStatement = statement as ThrowStatementSyntax;
-
- var objectCreation = throwStatement?.Expression as ObjectCreationExpressionSyntax;
- if (objectCreation == null)
- {
- return false;
- }
-
- var symbol = semanticModel.GetSymbolInfo(objectCreation.Type).Symbol;
- return symbol.IsArgumentExceptionType(semanticModel);
- }
- }
-}
\ No newline at end of file
diff --git a/src/ErrorProne.NET.StructAnalyzers/UseInModifierForReadOnlyStructAnalyzer.cs b/src/ErrorProne.NET.StructAnalyzers/UseInModifierForReadOnlyStructAnalyzer.cs
index dfd0035..8340e27 100644
--- a/src/ErrorProne.NET.StructAnalyzers/UseInModifierForReadOnlyStructAnalyzer.cs
+++ b/src/ErrorProne.NET.StructAnalyzers/UseInModifierForReadOnlyStructAnalyzer.cs
@@ -25,8 +25,13 @@ public override void Initialize(AnalysisContext context)
context.EnableConcurrentExecution();
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
- context.RegisterSymbolAction(AnalyzeNamedType, SymbolKind.NamedType);
- context.RegisterSymbolAction(AnalyzeMethod, SymbolKind.Method);
+ context.RegisterCompilationStartAction(context =>
+ {
+ var taskTypesInfo = new TaskTypesInfo(context.Compilation);
+
+ context.RegisterSymbolAction(context => AnalyzeNamedType(context), SymbolKind.NamedType);
+ context.RegisterSymbolAction(context => AnalyzeMethod(context, taskTypesInfo), SymbolKind.Method);
+ });
}
private void AnalyzeNamedType(SymbolAnalysisContext context)
@@ -51,12 +56,12 @@ private void AnalyzeNamedType(SymbolAnalysisContext context)
}
}
- private void AnalyzeMethod(SymbolAnalysisContext context)
+ private void AnalyzeMethod(SymbolAnalysisContext context, TaskTypesInfo info)
{
context.TryGetSemanticModel(out var semanticModel);
var method = (IMethodSymbol) context.Symbol;
- if (IsOverridenMethod(method) || method.IsAsyncOrTaskBased(context.Compilation) || method.IsIteratorBlock())
+ if (IsOverridenMethod(method) || method.IsAsyncOrTaskBased(info) || method.IsIteratorBlock())
{
// If the method overrides a base method or implements an interface,
// then we can't enforce 'in'-modifier