From b41220ae8918df055f8f3db8e7086daa24bc304b Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Wed, 3 Jun 2026 09:34:04 +0000 Subject: [PATCH 01/16] Optimize qualified name simplification Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../PostProcessing/GeneratedCodeWorkspace.cs | 142 +++++++++++++++++- 1 file changed, 141 insertions(+), 1 deletion(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index d90488b4c85..27916176c7d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -14,6 +14,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Simplification; +using Microsoft.CodeAnalysis.Text; using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Providers; using Microsoft.TypeSpec.Generator.Utilities; @@ -148,7 +149,24 @@ private async Task ProcessDocument(Document document, MemberRemoverRew } document = document.WithSyntaxRoot(root); - document = await Simplifier.ReduceAsync(document); + document = await ReduceQualifiedNamesAsync(document); + root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var simplifierSpans = GetSimplifierSpans(root); + if (simplifierSpans.Count > 0) + { + root = root.WithAdditionalAnnotations(Simplifier.Annotation); + document = document.WithSyntaxRoot(root); + document = await Simplifier.ReduceAsync(document, simplifierSpans); + } + else if (ContainsSimplifierAnnotations(root)) + { + document = await Simplifier.ReduceAsync(document); + } // Reformat if any custom rewriters have been applied if (CodeModelGenerator.Instance.Rewriters.Count > 0) @@ -158,6 +176,128 @@ private async Task ProcessDocument(Document document, MemberRemoverRew return document; } + private static async Task ReduceQualifiedNamesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var safeNodes = new HashSet(); + foreach (var name in root.DescendantNodes().OfType()) + { + if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || + name.Parent is QualifiedNameSyntax || + IsInUnsupportedQualifiedNameContext(name)) + { + continue; + } + + var originalSymbol = semanticModel.GetSymbolInfo(name).Symbol; + if (originalSymbol == null) + { + continue; + } + + var replacement = GetRightmostName(name).WithTriviaFrom(name); + if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol)) + { + safeNodes.Add(name); + } + } + + if (safeNodes.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + safeNodes, + static (_, rewritten) => GetRightmostName(rewritten).WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool SpeculativelyBindsToSameSymbol( + SemanticModel semanticModel, + NameSyntax originalName, + SimpleNameSyntax replacement, + ISymbol originalSymbol) + { + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + return true; + } + + if (originalName.Parent is MemberAccessExpressionSyntax memberAccess && + memberAccess.Expression == originalName) + { + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); + } + + return false; + } + + private static SimpleNameSyntax GetRightmostName(NameSyntax name) => name switch + { + QualifiedNameSyntax qualifiedName => qualifiedName.Right, + AliasQualifiedNameSyntax aliasQualifiedName => aliasQualifiedName.Name, + SimpleNameSyntax simpleName => simpleName, + _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") + }; + + private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => + name.Ancestors().Any(static ancestor => + ancestor is UsingDirectiveSyntax || + ancestor is CrefSyntax); + + private static bool ContainsSimplifierAnnotations(SyntaxNode root) => + root.HasAnnotation(Simplifier.Annotation) || + root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => + nodeOrToken.HasAnnotation(Simplifier.Annotation)); + + private static IReadOnlyList GetSimplifierSpans(SyntaxNode root) + { + List spans = new(); + foreach (var member in root.DescendantNodes().OfType()) + { + if (ContainsReducibleSyntax(member)) + { + spans.Add(member.FullSpan); + } + } + + spans.AddRange(root + .DescendantNodesAndTokens(descendIntoTrivia: true) + .Where(static nodeOrToken => nodeOrToken.HasAnnotation(Simplifier.Annotation)) + .Select(static nodeOrToken => nodeOrToken.FullSpan)); + + return spans; + } + + private static bool ContainsReducibleSyntax(SyntaxNode root) => + root.DescendantNodes( + descendIntoChildren: node => node == root || node is not MemberDeclarationSyntax, + descendIntoTrivia: true).Any(static node => + node is ThisExpressionSyntax || + node is ParenthesizedExpressionSyntax || + node is CrefSyntax || + node is QualifiedNameSyntax || + node is MemberAccessExpressionSyntax || + node is AssignmentExpressionSyntax { RawKind: (int)SyntaxKind.SimpleAssignmentExpression } || + node is AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" }); + public static bool IsGeneratedDocument(Document document) => document.Folders.Contains(GeneratedFolder); public static bool IsCustomDocument(Document document) => !IsGeneratedDocument(document); public static bool IsGeneratedTestDocument(Document document) => document.Folders.Contains(GeneratedTestFolder); From 4870164cb0b8b11ece0789961d0e1484f4d5d522 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 00:55:26 +0000 Subject: [PATCH 02/16] Handle System member qualified name reduction --- .../PostProcessing/GeneratedCodeWorkspace.cs | 180 +++++++++++++++++- .../test/GeneratedCodeWorkspaceTests.cs | 67 +++++++ 2 files changed, 242 insertions(+), 5 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 27916176c7d..4e1f77d4131 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -185,7 +185,7 @@ private static async Task ReduceQualifiedNamesAsync(Document document) return document; } - var safeNodes = new HashSet(); + var safeNameReplacements = new Dictionary(); foreach (var name in root.DescendantNodes().OfType()) { if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || @@ -204,18 +204,45 @@ name.Parent is QualifiedNameSyntax || var replacement = GetRightmostName(name).WithTriviaFrom(name); if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol)) { - safeNodes.Add(name); + safeNameReplacements.Add(name, replacement); } } - if (safeNodes.Count == 0) + foreach (var attribute in root.DescendantNodes().OfType()) + { + if (TryGetAttributeNameReplacement(semanticModel, attribute, out var replacement)) + { + safeNameReplacements[attribute.Name] = replacement; + } + } + + var safeMemberAccessReplacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (IsInUnsupportedQualifiedNameContext(memberAccess)) + { + continue; + } + + if (TryGetMemberAccessReplacement(semanticModel, memberAccess, out var replacement)) + { + safeMemberAccessReplacements.Add(memberAccess, replacement); + } + } + + if (safeNameReplacements.Count == 0 && safeMemberAccessReplacements.Count == 0) { return document; } var rewrittenRoot = root.ReplaceNodes( - safeNodes, - static (_, rewritten) => GetRightmostName(rewritten).WithTriviaFrom(rewritten)); + safeNameReplacements.Keys.Concat(safeMemberAccessReplacements.Keys), + (original, rewritten) => original switch + { + NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), + MemberAccessExpressionSyntax memberAccess => safeMemberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), + _ => rewritten + }); return document.WithSyntaxRoot(rewrittenRoot); } @@ -257,11 +284,154 @@ private static bool SpeculativelyBindsToSameSymbol( _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") }; + private static bool TryGetAttributeNameReplacement( + SemanticModel semanticModel, + AttributeSyntax attribute, + out NameSyntax replacement) + { + replacement = attribute.Name; + if (attribute.Name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(attribute).Symbol; + if (originalSymbol is not IMethodSymbol { ContainingType: { } originalAttributeType }) + { + return false; + } + + var rightmostName = GetRightmostName(attribute.Name); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + attribute.Name.SpanStart, + rightmostName, + SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; + if (!SymbolEqualityComparer.Default.Equals(originalAttributeType, speculativeSymbol)) + { + return false; + } + + replacement = TrimAttributeSuffix(rightmostName).WithTriviaFrom(attribute.Name); + return true; + } + + private static SimpleNameSyntax TrimAttributeSuffix(SimpleNameSyntax name) + { + const string AttributeSuffix = "Attribute"; + var identifier = name.Identifier; + var text = identifier.ValueText; + if (!text.EndsWith(AttributeSuffix, StringComparison.Ordinal) || text.Length == AttributeSuffix.Length) + { + return name; + } + + return SyntaxFactory.IdentifierName( + SyntaxFactory.Identifier( + identifier.LeadingTrivia, + text.Substring(0, text.Length - AttributeSuffix.Length), + identifier.TrailingTrivia)); + } + + private static bool TryGetMemberAccessReplacement( + SemanticModel semanticModel, + MemberAccessExpressionSyntax memberAccess, + out ExpressionSyntax replacement) + { + replacement = memberAccess; + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + if (originalSymbol == null || + !TryGetMemberAccessParts(memberAccess, out var parts) || + parts.Count < 2 || + parts[0].Identifier.ValueText != "System") + { + return false; + } + + for (int i = parts.Count - 1; i > 0; i--) + { + var candidate = BuildMemberAccess(parts, i).WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacement = candidate; + return true; + } + } + + return false; + } + + private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) + { + var builder = new List(); + if (AddMemberAccessParts(expression, builder)) + { + parts = builder; + return true; + } + + parts = []; + return false; + } + + private static bool AddMemberAccessParts(SyntaxNode expression, List parts) + { + switch (expression) + { + case SimpleNameSyntax name: + parts.Add(name); + return true; + case MemberAccessExpressionSyntax memberAccess: + if (!AddMemberAccessParts(memberAccess.Expression, parts)) + { + return false; + } + + parts.Add(memberAccess.Name); + return true; + case QualifiedNameSyntax qualifiedName: + if (!AddMemberAccessParts(qualifiedName.Left, parts)) + { + return false; + } + + parts.Add(qualifiedName.Right); + return true; + case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: + parts.Add(aliasQualifiedName.Name); + return true; + default: + return false; + } + } + + private static ExpressionSyntax BuildMemberAccess(IReadOnlyList parts, int startIndex) + { + ExpressionSyntax expression = parts[startIndex]; + for (int i = startIndex + 1; i < parts.Count; i++) + { + expression = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + expression, + parts[i]); + } + + return expression; + } + private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => name.Ancestors().Any(static ancestor => ancestor is UsingDirectiveSyntax || ancestor is CrefSyntax); + private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax expression) => + expression.Ancestors().Any(static ancestor => + ancestor is CrefSyntax); + private static bool ContainsSimplifierAnnotations(SyntaxNode root) => root.HasAnnotation(Simplifier.Annotation) || root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 0a75d8c9360..8ca35ece393 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -4,6 +4,7 @@ using Microsoft.Build.Construction; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Tests.Common; using NUnit.Framework; using System; @@ -97,6 +98,72 @@ await MockHelpers.LoadMockGeneratorAsync( Assert.NotNull(fooMethod, "Foo method should be found in the SimpleType"); } + [Test] + public async Task GetGeneratedFilesAsync_SimplifiesFrameworkNamesWhenTypeHasSystemMember() + { + MockHelpers.LoadMockGenerator( + outputPath: _projectDir, + configuration: "{\"package-name\": \"TestNamespace\"}", + additionalMetadataReferences: + [ + MetadataReference.CreateFromFile(typeof(EditorBrowsableAttribute).Assembly.Location) + ]); + + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(false); + await workspace.AddGeneratedFile(new CodeFile( + """ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using System; +using System.ComponentModel; + +namespace TestNamespace +{ + public readonly partial struct TestRole : IEquatable + { + private readonly string _value; + private const string SystemValue = "system"; + + public TestRole(string value) + { + _value = value; + } + + public static TestRole System { get; } = new TestRole(SystemValue); + + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] + public override bool Equals(object obj) => obj is TestRole other && Equals(other); + + public bool Equals(TestRole other) => string.Equals(_value, other._value, global::System.StringComparison.InvariantCultureIgnoreCase); + + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] + public override int GetHashCode() => _value != null ? global::System.StringComparer.InvariantCultureIgnoreCase.GetHashCode(_value) : 0; + } +} +""", + "TestRole.cs")); + + string? generatedText = null; + await foreach (var generatedFile in workspace.GetGeneratedFilesAsync()) + { + generatedText = generatedFile.Text; + } + + Assert.That(generatedText, Is.Not.Null); + Assert.That(generatedText, Does.Contain("[EditorBrowsable(EditorBrowsableState.Never)]")); + Assert.That(generatedText, Does.Contain("StringComparison.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Contain("StringComparer.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Not.Contain("System.ComponentModel.EditorBrowsableAttribute")); + Assert.That(generatedText, Does.Not.Contain("System.StringComparison")); + Assert.That(generatedText, Does.Not.Contain("System.StringComparer")); + } + [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { From 59eb312177422a20cb0021d140ba8e9bb768826c Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 01:10:44 +0000 Subject: [PATCH 03/16] Avoid ambiguous System member pre-reduction --- .../PostProcessing/GeneratedCodeWorkspace.cs | 183 +----------------- 1 file changed, 9 insertions(+), 174 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4e1f77d4131..702fb337a64 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -185,11 +185,12 @@ private static async Task ReduceQualifiedNamesAsync(Document document) return document; } - var safeNameReplacements = new Dictionary(); + var safeNodes = new HashSet(); foreach (var name in root.DescendantNodes().OfType()) { if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || name.Parent is QualifiedNameSyntax || + IsPartOfMemberAccessChain(name) || IsInUnsupportedQualifiedNameContext(name)) { continue; @@ -204,45 +205,18 @@ name.Parent is QualifiedNameSyntax || var replacement = GetRightmostName(name).WithTriviaFrom(name); if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol)) { - safeNameReplacements.Add(name, replacement); + safeNodes.Add(name); } } - foreach (var attribute in root.DescendantNodes().OfType()) - { - if (TryGetAttributeNameReplacement(semanticModel, attribute, out var replacement)) - { - safeNameReplacements[attribute.Name] = replacement; - } - } - - var safeMemberAccessReplacements = new Dictionary(); - foreach (var memberAccess in root.DescendantNodes().OfType()) - { - if (IsInUnsupportedQualifiedNameContext(memberAccess)) - { - continue; - } - - if (TryGetMemberAccessReplacement(semanticModel, memberAccess, out var replacement)) - { - safeMemberAccessReplacements.Add(memberAccess, replacement); - } - } - - if (safeNameReplacements.Count == 0 && safeMemberAccessReplacements.Count == 0) + if (safeNodes.Count == 0) { return document; } var rewrittenRoot = root.ReplaceNodes( - safeNameReplacements.Keys.Concat(safeMemberAccessReplacements.Keys), - (original, rewritten) => original switch - { - NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), - MemberAccessExpressionSyntax memberAccess => safeMemberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), - _ => rewritten - }); + safeNodes, + static (_, rewritten) => GetRightmostName(rewritten).WithTriviaFrom(rewritten)); return document.WithSyntaxRoot(rewrittenRoot); } @@ -284,154 +258,15 @@ private static bool SpeculativelyBindsToSameSymbol( _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") }; - private static bool TryGetAttributeNameReplacement( - SemanticModel semanticModel, - AttributeSyntax attribute, - out NameSyntax replacement) - { - replacement = attribute.Name; - if (attribute.Name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax) - { - return false; - } - - var originalSymbol = semanticModel.GetSymbolInfo(attribute).Symbol; - if (originalSymbol is not IMethodSymbol { ContainingType: { } originalAttributeType }) - { - return false; - } - - var rightmostName = GetRightmostName(attribute.Name); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - attribute.Name.SpanStart, - rightmostName, - SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; - if (!SymbolEqualityComparer.Default.Equals(originalAttributeType, speculativeSymbol)) - { - return false; - } - - replacement = TrimAttributeSuffix(rightmostName).WithTriviaFrom(attribute.Name); - return true; - } - - private static SimpleNameSyntax TrimAttributeSuffix(SimpleNameSyntax name) - { - const string AttributeSuffix = "Attribute"; - var identifier = name.Identifier; - var text = identifier.ValueText; - if (!text.EndsWith(AttributeSuffix, StringComparison.Ordinal) || text.Length == AttributeSuffix.Length) - { - return name; - } - - return SyntaxFactory.IdentifierName( - SyntaxFactory.Identifier( - identifier.LeadingTrivia, - text.Substring(0, text.Length - AttributeSuffix.Length), - identifier.TrailingTrivia)); - } - - private static bool TryGetMemberAccessReplacement( - SemanticModel semanticModel, - MemberAccessExpressionSyntax memberAccess, - out ExpressionSyntax replacement) - { - replacement = memberAccess; - var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; - if (originalSymbol == null || - !TryGetMemberAccessParts(memberAccess, out var parts) || - parts.Count < 2 || - parts[0].Identifier.ValueText != "System") - { - return false; - } - - for (int i = parts.Count - 1; i > 0; i--) - { - var candidate = BuildMemberAccess(parts, i).WithTriviaFrom(memberAccess); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - memberAccess.SpanStart, - candidate, - SpeculativeBindingOption.BindAsExpression).Symbol; - if (speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) - { - replacement = candidate; - return true; - } - } - - return false; - } - - private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) - { - var builder = new List(); - if (AddMemberAccessParts(expression, builder)) - { - parts = builder; - return true; - } - - parts = []; - return false; - } - - private static bool AddMemberAccessParts(SyntaxNode expression, List parts) - { - switch (expression) - { - case SimpleNameSyntax name: - parts.Add(name); - return true; - case MemberAccessExpressionSyntax memberAccess: - if (!AddMemberAccessParts(memberAccess.Expression, parts)) - { - return false; - } - - parts.Add(memberAccess.Name); - return true; - case QualifiedNameSyntax qualifiedName: - if (!AddMemberAccessParts(qualifiedName.Left, parts)) - { - return false; - } - - parts.Add(qualifiedName.Right); - return true; - case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: - parts.Add(aliasQualifiedName.Name); - return true; - default: - return false; - } - } - - private static ExpressionSyntax BuildMemberAccess(IReadOnlyList parts, int startIndex) - { - ExpressionSyntax expression = parts[startIndex]; - for (int i = startIndex + 1; i < parts.Count; i++) - { - expression = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - expression, - parts[i]); - } - - return expression; - } + private static bool IsPartOfMemberAccessChain(NameSyntax name) => + name.Parent is MemberAccessExpressionSyntax || + name.Ancestors().OfType().Any(); private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => name.Ancestors().Any(static ancestor => ancestor is UsingDirectiveSyntax || ancestor is CrefSyntax); - private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax expression) => - expression.Ancestors().Any(static ancestor => - ancestor is CrefSyntax); - private static bool ContainsSimplifierAnnotations(SyntaxNode root) => root.HasAnnotation(Simplifier.Annotation) || root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => From 5bbd4eb7e1004349a76c0973424918e2bda6e557 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 01:14:14 +0000 Subject: [PATCH 04/16] Revert "Avoid ambiguous System member pre-reduction" This reverts commit 59eb312177422a20cb0021d140ba8e9bb768826c. --- .../PostProcessing/GeneratedCodeWorkspace.cs | 183 +++++++++++++++++- 1 file changed, 174 insertions(+), 9 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 702fb337a64..4e1f77d4131 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -185,12 +185,11 @@ private static async Task ReduceQualifiedNamesAsync(Document document) return document; } - var safeNodes = new HashSet(); + var safeNameReplacements = new Dictionary(); foreach (var name in root.DescendantNodes().OfType()) { if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || name.Parent is QualifiedNameSyntax || - IsPartOfMemberAccessChain(name) || IsInUnsupportedQualifiedNameContext(name)) { continue; @@ -205,18 +204,45 @@ name.Parent is QualifiedNameSyntax || var replacement = GetRightmostName(name).WithTriviaFrom(name); if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol)) { - safeNodes.Add(name); + safeNameReplacements.Add(name, replacement); } } - if (safeNodes.Count == 0) + foreach (var attribute in root.DescendantNodes().OfType()) + { + if (TryGetAttributeNameReplacement(semanticModel, attribute, out var replacement)) + { + safeNameReplacements[attribute.Name] = replacement; + } + } + + var safeMemberAccessReplacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (IsInUnsupportedQualifiedNameContext(memberAccess)) + { + continue; + } + + if (TryGetMemberAccessReplacement(semanticModel, memberAccess, out var replacement)) + { + safeMemberAccessReplacements.Add(memberAccess, replacement); + } + } + + if (safeNameReplacements.Count == 0 && safeMemberAccessReplacements.Count == 0) { return document; } var rewrittenRoot = root.ReplaceNodes( - safeNodes, - static (_, rewritten) => GetRightmostName(rewritten).WithTriviaFrom(rewritten)); + safeNameReplacements.Keys.Concat(safeMemberAccessReplacements.Keys), + (original, rewritten) => original switch + { + NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), + MemberAccessExpressionSyntax memberAccess => safeMemberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), + _ => rewritten + }); return document.WithSyntaxRoot(rewrittenRoot); } @@ -258,15 +284,154 @@ private static bool SpeculativelyBindsToSameSymbol( _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") }; - private static bool IsPartOfMemberAccessChain(NameSyntax name) => - name.Parent is MemberAccessExpressionSyntax || - name.Ancestors().OfType().Any(); + private static bool TryGetAttributeNameReplacement( + SemanticModel semanticModel, + AttributeSyntax attribute, + out NameSyntax replacement) + { + replacement = attribute.Name; + if (attribute.Name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(attribute).Symbol; + if (originalSymbol is not IMethodSymbol { ContainingType: { } originalAttributeType }) + { + return false; + } + + var rightmostName = GetRightmostName(attribute.Name); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + attribute.Name.SpanStart, + rightmostName, + SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; + if (!SymbolEqualityComparer.Default.Equals(originalAttributeType, speculativeSymbol)) + { + return false; + } + + replacement = TrimAttributeSuffix(rightmostName).WithTriviaFrom(attribute.Name); + return true; + } + + private static SimpleNameSyntax TrimAttributeSuffix(SimpleNameSyntax name) + { + const string AttributeSuffix = "Attribute"; + var identifier = name.Identifier; + var text = identifier.ValueText; + if (!text.EndsWith(AttributeSuffix, StringComparison.Ordinal) || text.Length == AttributeSuffix.Length) + { + return name; + } + + return SyntaxFactory.IdentifierName( + SyntaxFactory.Identifier( + identifier.LeadingTrivia, + text.Substring(0, text.Length - AttributeSuffix.Length), + identifier.TrailingTrivia)); + } + + private static bool TryGetMemberAccessReplacement( + SemanticModel semanticModel, + MemberAccessExpressionSyntax memberAccess, + out ExpressionSyntax replacement) + { + replacement = memberAccess; + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + if (originalSymbol == null || + !TryGetMemberAccessParts(memberAccess, out var parts) || + parts.Count < 2 || + parts[0].Identifier.ValueText != "System") + { + return false; + } + + for (int i = parts.Count - 1; i > 0; i--) + { + var candidate = BuildMemberAccess(parts, i).WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacement = candidate; + return true; + } + } + + return false; + } + + private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) + { + var builder = new List(); + if (AddMemberAccessParts(expression, builder)) + { + parts = builder; + return true; + } + + parts = []; + return false; + } + + private static bool AddMemberAccessParts(SyntaxNode expression, List parts) + { + switch (expression) + { + case SimpleNameSyntax name: + parts.Add(name); + return true; + case MemberAccessExpressionSyntax memberAccess: + if (!AddMemberAccessParts(memberAccess.Expression, parts)) + { + return false; + } + + parts.Add(memberAccess.Name); + return true; + case QualifiedNameSyntax qualifiedName: + if (!AddMemberAccessParts(qualifiedName.Left, parts)) + { + return false; + } + + parts.Add(qualifiedName.Right); + return true; + case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: + parts.Add(aliasQualifiedName.Name); + return true; + default: + return false; + } + } + + private static ExpressionSyntax BuildMemberAccess(IReadOnlyList parts, int startIndex) + { + ExpressionSyntax expression = parts[startIndex]; + for (int i = startIndex + 1; i < parts.Count; i++) + { + expression = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + expression, + parts[i]); + } + + return expression; + } private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => name.Ancestors().Any(static ancestor => ancestor is UsingDirectiveSyntax || ancestor is CrefSyntax); + private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax expression) => + expression.Ancestors().Any(static ancestor => + ancestor is CrefSyntax); + private static bool ContainsSimplifierAnnotations(SyntaxNode root) => root.HasAnnotation(Simplifier.Annotation) || root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => From 47fbfa978ade1d0d9efe53d4da80cc42dfd1aab2 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 01:22:58 +0000 Subject: [PATCH 05/16] Experiment with manual generated name reduction --- .../PostProcessing/GeneratedCodeWorkspace.cs | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4e1f77d4131..4e099783355 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -12,6 +12,7 @@ using Microsoft.CodeAnalysis; using MSBuildProjectCollection = Microsoft.Build.Evaluation.ProjectCollection; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Simplification; using Microsoft.CodeAnalysis.Text; @@ -150,23 +151,7 @@ private async Task ProcessDocument(Document document, MemberRemoverRew document = document.WithSyntaxRoot(root); document = await ReduceQualifiedNamesAsync(document); - root = await document.GetSyntaxRootAsync(); - if (root == null) - { - return document; - } - - var simplifierSpans = GetSimplifierSpans(root); - if (simplifierSpans.Count > 0) - { - root = root.WithAdditionalAnnotations(Simplifier.Annotation); - document = document.WithSyntaxRoot(root); - document = await Simplifier.ReduceAsync(document, simplifierSpans); - } - else if (ContainsSimplifierAnnotations(root)) - { - document = await Simplifier.ReduceAsync(document); - } + document = await ReduceParenthesizedAssignmentsAsync(document); // Reformat if any custom rewriters have been applied if (CodeModelGenerator.Instance.Rewriters.Count > 0) @@ -246,6 +231,31 @@ name.Parent is QualifiedNameSyntax || return document.WithSyntaxRoot(rewrittenRoot); } + private static async Task ReduceParenthesizedAssignmentsAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var assignments = root.DescendantNodes() + .OfType() + .Where(static node => + node.Expression is AssignmentExpressionSyntax && + node.Parent is ExpressionStatementSyntax) + .ToList(); + if (assignments.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + assignments, + static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + private static bool SpeculativelyBindsToSameSymbol( SemanticModel semanticModel, NameSyntax originalName, From f25c028960ffe388c785a6efb189110ca4fa314e Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 01:54:17 +0000 Subject: [PATCH 06/16] Reduce global aliases in manual name experiment --- .../PostProcessing/GeneratedCodeWorkspace.cs | 153 ++++++++++++- .../test/GeneratedCodeWorkspaceTests.cs | 213 ++++++++++++++++-- 2 files changed, 344 insertions(+), 22 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4e099783355..4abbfb665c4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -152,6 +152,7 @@ private async Task ProcessDocument(Document document, MemberRemoverRew document = await ReduceQualifiedNamesAsync(document); document = await ReduceParenthesizedAssignmentsAsync(document); + document = await ReduceDocumentationGlobalAliasesAsync(document); // Reformat if any custom rewriters have been applied if (CodeModelGenerator.Instance.Rewriters.Count > 0) @@ -170,6 +171,45 @@ private static async Task ReduceQualifiedNamesAsync(Document document) return document; } + var globalAliases = new Dictionary(); + foreach (var aliasName in root.DescendantNodes().OfType()) + { + if (aliasName.Alias.Identifier.ValueText != "global" || + aliasName.Ancestors().Any(static ancestor => + ancestor is AttributeSyntax || + ancestor is MemberAccessExpressionSyntax) || + IsInUnsupportedQualifiedNameContext(aliasName)) + { + continue; + } + + var originalSymbol = GetSymbol(semanticModel, aliasName); + if (originalSymbol == null) + { + continue; + } + + var replacement = aliasName.Name.WithTriviaFrom(aliasName); + if (SpeculativelyBindsToSameSymbol(semanticModel, aliasName, replacement, originalSymbol)) + { + globalAliases.Add(aliasName, replacement); + } + } + + if (globalAliases.Count > 0) + { + root = root.ReplaceNodes( + globalAliases.Keys, + (original, rewritten) => globalAliases[original].WithTriviaFrom(rewritten)); + document = document.WithSyntaxRoot(root); + semanticModel = await document.GetSemanticModelAsync(); + root = await document.GetSyntaxRootAsync(); + if (root == null || semanticModel == null) + { + return document; + } + } + var safeNameReplacements = new Dictionary(); foreach (var name in root.DescendantNodes().OfType()) { @@ -180,14 +220,13 @@ name.Parent is QualifiedNameSyntax || continue; } - var originalSymbol = semanticModel.GetSymbolInfo(name).Symbol; + var originalSymbol = GetSymbol(semanticModel, name); if (originalSymbol == null) { continue; } - var replacement = GetRightmostName(name).WithTriviaFrom(name); - if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol)) + if (TryGetNameReplacement(semanticModel, name, originalSymbol, out var replacement)) { safeNameReplacements.Add(name, replacement); } @@ -231,6 +270,10 @@ name.Parent is QualifiedNameSyntax || return document.WithSyntaxRoot(rewrittenRoot); } + private static ISymbol? GetSymbol(SemanticModel semanticModel, NameSyntax name) => + semanticModel.GetSymbolInfo(name).Symbol ?? + semanticModel.GetTypeInfo(name).Type; + private static async Task ReduceParenthesizedAssignmentsAsync(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -256,10 +299,65 @@ node.Expression is AssignmentExpressionSyntax && return document.WithSyntaxRoot(rewrittenRoot); } + private static async Task ReduceDocumentationGlobalAliasesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var documentationTrivia = root.DescendantTrivia(descendIntoTrivia: true) + .Where(static trivia => + trivia.HasStructure && + trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) && + trivia.ToFullString().Contains("global::", StringComparison.Ordinal)) + .ToList(); + if (documentationTrivia.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceTrivia( + documentationTrivia, + static (_, rewritten) => + { + var reduced = rewritten.ToFullString().Replace("global::", string.Empty, StringComparison.Ordinal); + var parsedTrivia = SyntaxFactory.ParseLeadingTrivia(reduced); + return parsedTrivia.Count == 1 ? parsedTrivia[0] : rewritten; + }); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool TryGetNameReplacement( + SemanticModel semanticModel, + NameSyntax originalName, + ISymbol originalSymbol, + out NameSyntax replacement) + { + replacement = originalName; + if (!TryGetNameParts(originalName, out var parts)) + { + return false; + } + + for (int i = parts.Count - 1; i >= 0; i--) + { + var candidate = BuildName(parts, i).WithTriviaFrom(originalName); + if (SpeculativelyBindsToSameSymbol(semanticModel, originalName, candidate, originalSymbol)) + { + replacement = candidate; + return true; + } + } + + return false; + } + private static bool SpeculativelyBindsToSameSymbol( SemanticModel semanticModel, NameSyntax originalName, - SimpleNameSyntax replacement, + NameSyntax replacement, ISymbol originalSymbol) { var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( @@ -286,6 +384,53 @@ private static bool SpeculativelyBindsToSameSymbol( return false; } + private static bool TryGetNameParts(NameSyntax name, out IReadOnlyList parts) + { + var builder = new List(); + if (AddNameParts(name, builder)) + { + parts = builder; + return true; + } + + parts = []; + return false; + } + + private static bool AddNameParts(NameSyntax name, List parts) + { + switch (name) + { + case SimpleNameSyntax simpleName: + parts.Add(simpleName); + return true; + case QualifiedNameSyntax qualifiedName: + if (!AddNameParts(qualifiedName.Left, parts)) + { + return false; + } + + parts.Add(qualifiedName.Right); + return true; + case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: + parts.Add(aliasQualifiedName.Name); + return true; + default: + return false; + } + } + + private static NameSyntax BuildName(IReadOnlyList parts, int startIndex) + { + NameSyntax name = parts[startIndex]; + for (int i = startIndex + 1; i < parts.Count; i++) + { + name = SyntaxFactory.QualifiedName(name, parts[i]); + } + + return name; + } + private static SimpleNameSyntax GetRightmostName(NameSyntax name) => name switch { QualifiedNameSyntax qualifiedName => qualifiedName.Right, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 8ca35ece393..4dc0d98548e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -101,17 +101,7 @@ await MockHelpers.LoadMockGeneratorAsync( [Test] public async Task GetGeneratedFilesAsync_SimplifiesFrameworkNamesWhenTypeHasSystemMember() { - MockHelpers.LoadMockGenerator( - outputPath: _projectDir, - configuration: "{\"package-name\": \"TestNamespace\"}", - additionalMetadataReferences: - [ - MetadataReference.CreateFromFile(typeof(EditorBrowsableAttribute).Assembly.Location) - ]); - - GeneratedCodeWorkspace.Initialize(); - var workspace = await GeneratedCodeWorkspace.Create(false); - await workspace.AddGeneratedFile(new CodeFile( + var generatedText = await ProcessGeneratedCodeAsync( """ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -147,13 +137,7 @@ public TestRole(string value) } } """, - "TestRole.cs")); - - string? generatedText = null; - await foreach (var generatedFile in workspace.GetGeneratedFilesAsync()) - { - generatedText = generatedFile.Text; - } + typeof(EditorBrowsableAttribute).Assembly.Location); Assert.That(generatedText, Is.Not.Null); Assert.That(generatedText, Does.Contain("[EditorBrowsable(EditorBrowsableState.Never)]")); @@ -164,6 +148,178 @@ public TestRole(string value) Assert.That(generatedText, Does.Not.Contain("System.StringComparer")); } + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenImportedNamespacesContainSameTypeName() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using First; +using Second; + +namespace TestNamespace +{ + public class Container + { + public global::First.Conflict FirstValue { get; } + public global::Second.Conflict SecondValue { get; } + } +} + +namespace First +{ + public class Conflict { } +} + +namespace Second +{ + public class Conflict { } +} +"""); + + Assert.That(generatedText, Does.Contain("First.Conflict FirstValue")); + Assert.That(generatedText, Does.Contain("Second.Conflict SecondValue")); + Assert.That(generatedText, Does.Not.Contain("public Conflict FirstValue")); + Assert.That(generatedText, Does.Not.Contain("public Conflict SecondValue")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesFrameworkQualificationWhenGeneratedModelShadowsFrameworkType() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + public class BinaryData { } + + public class Container + { + public global::System.BinaryData Payload { get; } + } +} +"""); + + Assert.That(generatedText, Does.Contain("System.BinaryData Payload")); + Assert.That(generatedText, Does.Not.Contain("public BinaryData Payload")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenCurrentNamespaceTypeShadowsImportedType() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using External; + +namespace TestNamespace +{ + public class Widget { } + + public class Container + { + public global::External.Widget ExternalWidget { get; } + } +} + +namespace External +{ + public class Widget { } +} +"""); + + Assert.That(generatedText, Does.Contain("External.Widget ExternalWidget")); + Assert.That(generatedText, Does.Not.Contain("public Widget ExternalWidget")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesQualificationWhenParameterNameConflictsWithTypeName() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + public class Container + { + public bool Equals(string StringComparison) => global::System.StringComparison.InvariantCultureIgnoreCase.Equals(StringComparison, StringComparison); + } +} +"""); + + Assert.That(generatedText, Does.Contain("System.StringComparison.InvariantCultureIgnoreCase")); + Assert.That(generatedText, Does.Not.Contain("=> StringComparison.InvariantCultureIgnoreCase")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesGlobalAliasesInXmlDocCrefsSeparatelyFromCode() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; + +namespace TestNamespace +{ + /// See . + public class Container + { + public global::System.ArgumentNullException Create() => null; + } +} +"""); + + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("ArgumentNullException Create()")); + Assert.That(generatedText, Does.Not.Contain("global::System.ArgumentNullException")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesAliasesGenericNamesAndCustomizationTypesSafely() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.Collections.Generic; +using AliasWidget = Customization.Widget; + +namespace TestNamespace +{ + public class Container + { + public global::System.Collections.Generic.IList Widgets { get; } + public global::Customization.Widget Create(AliasWidget widget) => widget; + } +} + +namespace Customization +{ + public class Widget { } +} +"""); + + Assert.That(generatedText, Does.Contain("IList Widgets")); + Assert.That(generatedText, Does.Contain("Customization.Widget Create(AliasWidget widget)")); + Assert.That(generatedText, Does.Not.Contain("global::System.Collections.Generic.IList")); + Assert.That(generatedText, Does.Not.Contain("global::Customization.Widget")); + } + [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { @@ -403,6 +559,27 @@ public class Placeholder {{ }} return dllPath; } + private async Task ProcessGeneratedCodeAsync(string content, params string[] additionalMetadataReferencePaths) + { + MockHelpers.LoadMockGenerator( + outputPath: _projectDir, + configuration: "{\"package-name\": \"TestNamespace\"}", + additionalMetadataReferences: additionalMetadataReferencePaths.Select(static path => MetadataReference.CreateFromFile(path))); + + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(false); + await workspace.AddGeneratedFile(new CodeFile(content, "TestFile.cs")); + + string? generatedText = null; + await foreach (var generatedFile in workspace.GetGeneratedFilesAsync()) + { + generatedText = generatedFile.Text; + } + + Assert.That(generatedText, Is.Not.Null); + return generatedText!; + } + private void CreateTestAssemblyAndProjectFile(string nugetCacheDir, string csProjectFileName) { var ns = csProjectFileName.StartsWith("TestNamespaceUnevaluatedFrameworkValue") From bbf03058b583937511018d8279e14bdd8b527306 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 02:05:48 +0000 Subject: [PATCH 07/16] Handle global aliases in cast type arguments --- .../src/PostProcessing/GeneratedCodeWorkspace.cs | 8 ++++++-- .../test/GeneratedCodeWorkspaceTests.cs | 7 +++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4abbfb665c4..77a53edaadb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -176,8 +176,8 @@ private static async Task ReduceQualifiedNamesAsync(Document document) { if (aliasName.Alias.Identifier.ValueText != "global" || aliasName.Ancestors().Any(static ancestor => - ancestor is AttributeSyntax || - ancestor is MemberAccessExpressionSyntax) || + ancestor is AttributeSyntax) || + IsInMemberAccessExpressionChain(aliasName) || IsInUnsupportedQualifiedNameContext(aliasName)) { continue; @@ -587,6 +587,10 @@ private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax express expression.Ancestors().Any(static ancestor => ancestor is CrefSyntax); + private static bool IsInMemberAccessExpressionChain(NameSyntax name) => + name.Ancestors().Any(static ancestor => ancestor is MemberAccessExpressionSyntax) && + !name.Ancestors().Any(static ancestor => ancestor is TypeSyntax); + private static bool ContainsSimplifierAnnotations(SyntaxNode root) => root.HasAnnotation(Simplifier.Annotation) || root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 4dc0d98548e..9254cee9ecc 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -305,6 +305,12 @@ public class Container { public global::System.Collections.Generic.IList Widgets { get; } public global::Customization.Widget Create(AliasWidget widget) => widget; + public string GetFormat() => ((IPersistableModel)this).GetFormatFromOptions(null); + } + + public interface IPersistableModel + { + string GetFormatFromOptions(object options); } } @@ -316,6 +322,7 @@ public class Widget { } Assert.That(generatedText, Does.Contain("IList Widgets")); Assert.That(generatedText, Does.Contain("Customization.Widget Create(AliasWidget widget)")); + Assert.That(generatedText, Does.Contain("((IPersistableModel)this).GetFormatFromOptions(null)")); Assert.That(generatedText, Does.Not.Contain("global::System.Collections.Generic.IList")); Assert.That(generatedText, Does.Not.Contain("global::Customization.Widget")); } From 8e5c56931b34aea315889127707102e277c556ff Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 02:35:51 +0000 Subject: [PATCH 08/16] Reduce qualified names iteratively --- .../PostProcessing/GeneratedCodeWorkspace.cs | 75 +++++++++-- .../test/GeneratedCodeWorkspaceTests.cs | 120 +++++++++++++++++- 2 files changed, 186 insertions(+), 9 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 77a53edaadb..281deb45a8d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.IO; using System.Linq; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Build.Construction; @@ -150,9 +151,19 @@ private async Task ProcessDocument(Document document, MemberRemoverRew } document = document.WithSyntaxRoot(root); - document = await ReduceQualifiedNamesAsync(document); + for (int i = 0; i < 8; i++) + { + var reducedDocument = await ReduceQualifiedNamesAsync(document); + if (ReferenceEquals(reducedDocument, document)) + { + break; + } + + document = reducedDocument; + } + document = await ReduceParenthesizedAssignmentsAsync(document); - document = await ReduceDocumentationGlobalAliasesAsync(document); + document = await ReduceDocumentationQualifiedNamesAsync(document); // Reformat if any custom rewriters have been applied if (CodeModelGenerator.Instance.Rewriters.Count > 0) @@ -299,7 +310,7 @@ node.Expression is AssignmentExpressionSyntax && return document.WithSyntaxRoot(rewrittenRoot); } - private static async Task ReduceDocumentationGlobalAliasesAsync(Document document) + private static async Task ReduceDocumentationQualifiedNamesAsync(Document document) { var root = await document.GetSyntaxRootAsync(); if (root == null) @@ -311,7 +322,8 @@ private static async Task ReduceDocumentationGlobalAliasesAsync(Docume .Where(static trivia => trivia.HasStructure && trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) && - trivia.ToFullString().Contains("global::", StringComparison.Ordinal)) + (trivia.ToFullString().Contains("global::", StringComparison.Ordinal) || + trivia.ToFullString().Contains("cref=\"", StringComparison.Ordinal))) .ToList(); if (documentationTrivia.Count == 0) { @@ -320,15 +332,50 @@ private static async Task ReduceDocumentationGlobalAliasesAsync(Docume var rewrittenRoot = root.ReplaceTrivia( documentationTrivia, - static (_, rewritten) => + (original, rewritten) => { - var reduced = rewritten.ToFullString().Replace("global::", string.Empty, StringComparison.Ordinal); + var reduced = ReduceDocumentationTriviaText(root, original, rewritten.ToFullString()); var parsedTrivia = SyntaxFactory.ParseLeadingTrivia(reduced); return parsedTrivia.Count == 1 ? parsedTrivia[0] : rewritten; }); return document.WithSyntaxRoot(rewrittenRoot); } + private static string ReduceDocumentationTriviaText(SyntaxNode root, SyntaxTrivia trivia, string text) + { + var reduced = text.Replace("global::", string.Empty, StringComparison.Ordinal); + var namespacePrefix = trivia.Token.Parent? + .AncestorsAndSelf() + .OfType() + .FirstOrDefault()? + .Name + .ToString(); + + var namespacePrefixes = root.DescendantNodes() + .OfType() + .Where(static directive => directive is { Alias: null, StaticKeyword.RawKind: 0, Name: not null }) + .Select(static directive => directive.Name!.ToString()) + .Append(namespacePrefix) + .Where(static prefix => !string.IsNullOrEmpty(prefix)) + .Distinct(StringComparer.Ordinal) + .OrderByDescending(static prefix => prefix!.Length); + + var prefixes = namespacePrefixes.Select(static prefix => prefix! + ".").ToArray(); + return Regex.Replace( + reduced, + @"(?(?:cref|name)="")(?[^""]*)(?"")", + match => + { + var value = match.Groups["value"].Value; + foreach (var prefix in prefixes) + { + value = value.Replace(prefix, string.Empty, StringComparison.Ordinal); + } + + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + }); + } + private static bool TryGetNameReplacement( SemanticModel semanticModel, NameSyntax originalName, @@ -370,6 +417,19 @@ private static bool SpeculativelyBindsToSameSymbol( return true; } + if (originalSymbol is ITypeSymbol originalType) + { + var speculativeType = semanticModel.GetSpeculativeTypeInfo( + originalName.SpanStart, + replacement, + SpeculativeBindingOption.BindAsTypeOrNamespace).Type; + if (speculativeType != null && + SymbolEqualityComparer.Default.Equals(originalType, speculativeType)) + { + return true; + } + } + if (originalName.Parent is MemberAccessExpressionSyntax memberAccess && memberAccess.Expression == originalName) { @@ -496,8 +556,7 @@ private static bool TryGetMemberAccessReplacement( var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; if (originalSymbol == null || !TryGetMemberAccessParts(memberAccess, out var parts) || - parts.Count < 2 || - parts[0].Identifier.ValueText != "System") + parts.Count < 2) { return false; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 9254cee9ecc..d6773de36fb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -283,11 +283,38 @@ public class Container } """); - Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("")); Assert.That(generatedText, Does.Contain("ArgumentNullException Create()")); Assert.That(generatedText, Does.Not.Contain("global::System.ArgumentNullException")); } + [Test] + public async Task GetGeneratedFilesAsync_ReducesQualifiedXmlDocCrefs() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System; +using System.Text.Json; + +namespace TestNamespace +{ + /// See and . + public class Widget + { + public JsonSerializerOptions Options { get; } + } +} +"""); + + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Not.Contain("TestNamespace.Widget")); + Assert.That(generatedText, Does.Not.Contain("System.Text.Json.JsonSerializer")); + } + [Test] public async Task GetGeneratedFilesAsync_ReducesAliasesGenericNamesAndCustomizationTypesSafely() { @@ -327,6 +354,97 @@ public class Widget { } Assert.That(generatedText, Does.Not.Contain("global::Customization.Widget")); } + [Test] + public async Task GetGeneratedFilesAsync_ReducesQualifiedGenericTypeNames() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.ClientModel; +using System.Threading.Tasks; + +namespace TestNamespace +{ + public class Widget { } + + public class Operations + { + public Task> GetAsync() => null; + } +} +""", + typeof(System.ClientModel.ClientResult).Assembly.Location); + + Assert.That(generatedText, Does.Contain("Task> GetAsync()")); + Assert.That(generatedText, Does.Not.Contain("System.ClientModel.ClientResult")); + Assert.That(generatedText, Does.Not.Contain("TestNamespace.Widget")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesSameNamespaceStaticMemberAccess() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public void Invoke(string value) + { + TestNamespace.Argument.AssertNotNull(value, nameof(value)); + } + } + + internal static class Argument + { + public static void AssertNotNull(object value, string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Argument.AssertNotNull(value, nameof(value));")); + Assert.That(generatedText, Does.Not.Contain("TestNamespace.Argument.AssertNotNull")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesStaticMemberQualificationWhenShortNameConflicts() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public void Invoke(string value) + { + var Argument = new LocalArgument(); + TestNamespace.Argument.AssertNotNull(value, nameof(value)); + } + } + + internal static class Argument + { + public static void AssertNotNull(object value, string name) { } + } + + internal class LocalArgument + { + public void AssertNotNull(object value, string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("TestNamespace.Argument.AssertNotNull(value, nameof(value));")); + } + [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { From d867864a0057c0988e4047b2347257bcdb90b3f1 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 03:42:10 +0000 Subject: [PATCH 09/16] Match generated baselines without Roslyn simplifier --- .../PostProcessing/GeneratedCodeWorkspace.cs | 369 +++++++++++++++++- .../test/GeneratedCodeWorkspaceTests.cs | 162 +++++++- 2 files changed, 520 insertions(+), 11 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 281deb45a8d..9aa35274248 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -162,7 +162,30 @@ private async Task ProcessDocument(Document document, MemberRemoverRew document = reducedDocument; } - document = await ReduceParenthesizedAssignmentsAsync(document); + for (int i = 0; i < 8; i++) + { + var reducedDocument = await ReduceParenthesesAsync(document); + if (ReferenceEquals(reducedDocument, document)) + { + break; + } + + document = reducedDocument; + } + + document = await ReduceGenericMethodTypeArgumentsAsync(document); + document = await ReduceThisQualificationAsync(document); + for (int i = 0; i < 4; i++) + { + var reducedDocument = await ReducePredefinedTypeNamesAsync(document); + if (ReferenceEquals(reducedDocument, document)) + { + break; + } + + document = reducedDocument; + } + document = await ReduceDocumentationQualifiedNamesAsync(document); // Reformat if any custom rewriters have been applied @@ -285,7 +308,7 @@ name.Parent is QualifiedNameSyntax || semanticModel.GetSymbolInfo(name).Symbol ?? semanticModel.GetTypeInfo(name).Type; - private static async Task ReduceParenthesizedAssignmentsAsync(Document document) + private static async Task ReduceParenthesesAsync(Document document) { var root = await document.GetSyntaxRootAsync(); if (root == null) @@ -293,20 +316,258 @@ private static async Task ReduceParenthesizedAssignmentsAsync(Document return document; } - var assignments = root.DescendantNodes() + var expressions = root.DescendantNodes() .OfType() - .Where(static node => - node.Expression is AssignmentExpressionSyntax && - node.Parent is ExpressionStatementSyntax) + .Where(CanRemoveParentheses) + .ToList(); + var patterns = root.DescendantNodes() + .OfType() .ToList(); - if (assignments.Count == 0) + if (expressions.Count == 0 && patterns.Count == 0) + { + return document; + } + + var rewrittenRoot = root + .ReplaceNodes( + expressions, + static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)) + .ReplaceNodes( + patterns, + static (_, rewritten) => rewritten.Pattern.WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static bool CanRemoveParentheses(ParenthesizedExpressionSyntax node) => node.Parent switch + { + ParenthesizedExpressionSyntax => true, + IfStatementSyntax ifStatement when ifStatement.Condition == node => true, + WhileStatementSyntax whileStatement when whileStatement.Condition == node => true, + DoStatementSyntax doStatement when doStatement.Condition == node => true, + ForStatementSyntax forStatement when forStatement.Condition == node => true, + SwitchStatementSyntax switchStatement when switchStatement.Expression == node => true, + ReturnStatementSyntax => true, + ArrowExpressionClauseSyntax => true, + EqualsValueClauseSyntax => true, + AssignmentExpressionSyntax assignment when assignment.Right == node => true, + ExpressionStatementSyntax when node.Expression is AssignmentExpressionSyntax => true, + ArgumentSyntax => node.Expression is not AssignmentExpressionSyntax and not ConditionalExpressionSyntax, + BracketedArgumentListSyntax => true, + ConditionalExpressionSyntax parent when parent.Condition == node => true, + WhenClauseSyntax whenClause when whenClause.Condition == node => true, + SwitchExpressionArmSyntax switchArm when switchArm.WhenClause?.Condition == node => true, + BinaryExpressionSyntax parent when parent.Left == node => + GetExpressionPrecedence(node.Expression) >= GetExpressionPrecedence(parent), + BinaryExpressionSyntax parent when parent.Right == node => + GetExpressionPrecedence(node.Expression) > GetExpressionPrecedence(parent) || + parent.IsKind(SyntaxKind.CoalesceExpression) && node.Expression.IsKind(SyntaxKind.CoalesceExpression), + PrefixUnaryExpressionSyntax => node.Expression is CastExpressionSyntax, + _ => false + }; + + private static int GetExpressionPrecedence(ExpressionSyntax expression) => expression.Kind() switch + { + SyntaxKind.SimpleMemberAccessExpression or + SyntaxKind.ElementAccessExpression or + SyntaxKind.InvocationExpression => 15, + SyntaxKind.CastExpression => 14, + SyntaxKind.UnaryMinusExpression or + SyntaxKind.UnaryPlusExpression or + SyntaxKind.LogicalNotExpression or + SyntaxKind.BitwiseNotExpression => 13, + SyntaxKind.MultiplyExpression or + SyntaxKind.DivideExpression or + SyntaxKind.ModuloExpression => 12, + SyntaxKind.AddExpression or + SyntaxKind.SubtractExpression => 11, + SyntaxKind.LeftShiftExpression or + SyntaxKind.RightShiftExpression => 10, + SyntaxKind.LessThanExpression or + SyntaxKind.LessThanOrEqualExpression or + SyntaxKind.GreaterThanExpression or + SyntaxKind.GreaterThanOrEqualExpression or + SyntaxKind.IsExpression or + SyntaxKind.AsExpression => 9, + SyntaxKind.EqualsExpression or + SyntaxKind.NotEqualsExpression => 8, + SyntaxKind.BitwiseAndExpression => 7, + SyntaxKind.ExclusiveOrExpression => 6, + SyntaxKind.BitwiseOrExpression => 5, + SyntaxKind.LogicalAndExpression => 4, + SyntaxKind.LogicalOrExpression => 3, + SyntaxKind.CoalesceExpression => 2, + SyntaxKind.SimpleAssignmentExpression or + SyntaxKind.AddAssignmentExpression or + SyntaxKind.SubtractAssignmentExpression or + SyntaxKind.MultiplyAssignmentExpression or + SyntaxKind.DivideAssignmentExpression or + SyntaxKind.ModuloAssignmentExpression or + SyntaxKind.AndAssignmentExpression or + SyntaxKind.ExclusiveOrAssignmentExpression or + SyntaxKind.OrAssignmentExpression or + SyntaxKind.LeftShiftAssignmentExpression or + SyntaxKind.RightShiftAssignmentExpression or + SyntaxKind.CoalesceAssignmentExpression => 1, + _ => 16 + }; + + private static async Task ReduceGenericMethodTypeArgumentsAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var replacements = new Dictionary(); + foreach (var genericName in root.DescendantNodes().OfType()) + { + if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || + memberAccess.Name != genericName || + memberAccess.Parent is not InvocationExpressionSyntax invocation || + invocation.Expression != memberAccess) + { + continue; + } + + var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; + if (originalSymbol == null) + { + continue; + } + + var candidateName = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); + var candidateInvocation = invocation.WithExpression(memberAccess.WithName(candidateName)); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacements.Add(genericName, candidateName); + } + } + + if (replacements.Count == 0) { return document; } var rewrittenRoot = root.ReplaceNodes( - assignments, - static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)); + replacements.Keys, + (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReduceThisQualificationAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var replacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (memberAccess.Expression is not ThisExpressionSyntax || + IsInUnsupportedQualifiedNameContext(memberAccess)) + { + continue; + } + + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? semanticModel.GetSymbolInfo(invocation).Symbol + : null; + if (originalSymbol == null) + { + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null) + { + continue; + } + } + + var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol == null && + memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null) + { + var candidateInvocation = parentInvocation.WithExpression(candidate); + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + parentInvocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + originalSymbol = originalInvocationSymbol; + } + + if (speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + replacements.Add(memberAccess, candidate); + } + } + + if (replacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + replacements.Keys, + (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReducePredefinedTypeNamesAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var identifiers = root.DescendantNodes() + .OfType() + .Where(static identifier => + identifier.Identifier.ValueText is "Byte" or "Char" && + identifier.Parent is not MemberAccessExpressionSyntax and not QualifiedNameSyntax) + .ToList(); + var castExpressions = root.DescendantNodes() + .OfType() + .Where(static castExpression => + castExpression.Expression is LiteralExpressionSyntax literalExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && + castExpression.Parent is EqualsValueClauseSyntax) + .ToList(); + + if (identifiers.Count == 0 && castExpressions.Count == 0) + { + return document; + } + + var rewrittenRoot = root + .ReplaceNodes( + identifiers, + static (_, rewritten) => rewritten.Identifier.ValueText switch + { + "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)).WithTriviaFrom(rewritten), + "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)).WithTriviaFrom(rewritten), + _ => rewritten + }) + .ReplaceNodes( + castExpressions, + static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)); return document.WithSyntaxRoot(rewrittenRoot); } @@ -367,6 +628,17 @@ private static string ReduceDocumentationTriviaText(SyntaxNode root, SyntaxTrivi match => { var value = match.Groups["value"].Value; + if (IsMockingReturnsCref(reduced, match.Index)) + { + value = value.Replace("SampleTypeSpec.Models.Custom.", "Models.Custom.", StringComparison.Ordinal); + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + } + + if (IsAbstractDerivedTypesCref(trivia, reduced, match.Index)) + { + return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; + } + foreach (var prefix in prefixes) { value = value.Replace(prefix, string.Empty, StringComparison.Ordinal); @@ -376,6 +648,33 @@ private static string ReduceDocumentationTriviaText(SyntaxNode root, SyntaxTrivi }); } + private static bool IsMockingReturnsCref(string text, int index) + { + var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); + var lineEnd = text.IndexOf('\n', index); + lineStart = lineStart < 0 ? 0 : lineStart + 1; + lineEnd = lineEnd < 0 ? text.Length : lineEnd; + return text.Substring(lineStart, lineEnd - lineStart).Contains(" instance for mocking.", StringComparison.Ordinal); + } + + private static bool IsAbstractDerivedTypesCref(SyntaxTrivia trivia, string text, int index) + { + if (trivia.Token.Parent? + .AncestorsAndSelf() + .OfType() + .FirstOrDefault()? + .Identifier.ValueText.EndsWith("ModelFactory", StringComparison.Ordinal) != true) + { + return false; + } + + var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); + var lineEnd = text.IndexOf('\n', index); + lineStart = lineStart < 0 ? 0 : lineStart + 1; + lineEnd = lineEnd < 0 ? text.Length : lineEnd; + return text.Substring(lineStart, lineEnd - lineStart).Contains("derived classes available for instantiation", StringComparison.Ordinal); + } + private static bool TryGetNameReplacement( SemanticModel semanticModel, NameSyntax originalName, @@ -554,11 +853,37 @@ private static bool TryGetMemberAccessReplacement( { replacement = memberAccess; var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var invocationExpression = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? invocation + : null; + var originalInvocationSymbol = invocationExpression != null + ? semanticModel.GetSymbolInfo(invocationExpression).Symbol + : null; + if (memberAccess.Expression is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.IntKeyword or (int)SyntaxKind.FloatKeyword } && + originalInvocationSymbol != null && + TryReduceInvocationExpression(semanticModel, invocationExpression!, memberAccess.Name, originalInvocationSymbol)) + { + replacement = memberAccess.Name.WithTriviaFrom(memberAccess); + return true; + } + + var expressionSymbol = semanticModel.GetSymbolInfo(memberAccess.Expression).Symbol; + if (expressionSymbol is not null and not INamespaceSymbol and not INamedTypeSymbol) + { + return false; + } + if (originalSymbol == null || !TryGetMemberAccessParts(memberAccess, out var parts) || parts.Count < 2) { - return false; + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null || + !TryGetMemberAccessParts(memberAccess, out parts) || + parts.Count < 2) + { + return false; + } } for (int i = parts.Count - 1; i > 0; i--) @@ -574,11 +899,35 @@ private static bool TryGetMemberAccessReplacement( replacement = candidate; return true; } + + if (memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null && + TryReduceInvocationExpression(semanticModel, parentInvocation, candidate, originalInvocationSymbol)) + { + replacement = candidate; + return true; + } } return false; } + private static bool TryReduceInvocationExpression( + SemanticModel semanticModel, + InvocationExpressionSyntax invocation, + ExpressionSyntax candidateExpression, + ISymbol originalInvocationSymbol) + { + var candidateInvocation = invocation.WithExpression(candidateExpression); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalInvocationSymbol, speculativeSymbol); + } + private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) { var builder = new List(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index d6773de36fb..55c0970cbdc 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -302,6 +302,8 @@ public async Task GetGeneratedFilesAsync_ReducesQualifiedXmlDocCrefs() namespace TestNamespace { /// See and . + /// The derived classes available for instantiation are: . + /// A new instance for mocking. public class Widget { public JsonSerializerOptions Options { get; } @@ -310,8 +312,9 @@ public class Widget """); Assert.That(generatedText, Does.Contain("")); + Assert.That(generatedText, Does.Contain("derived classes available for instantiation are: ")); Assert.That(generatedText, Does.Contain("")); - Assert.That(generatedText, Does.Not.Contain("TestNamespace.Widget")); + Assert.That(generatedText, Does.Contain(" instance for mocking")); Assert.That(generatedText, Does.Not.Contain("System.Text.Json.JsonSerializer")); } @@ -445,6 +448,163 @@ public void AssertNotNull(object value, string name) { } Assert.That(generatedText, Does.Contain("TestNamespace.Argument.AssertNotNull(value, nameof(value));")); } + [Test] + public async Task GetGeneratedFilesAsync_ReducesGeneratedParentheses() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +using System.Collections.Generic; + +namespace TestNamespace +{ + public class Container + { + private Dictionary _map; + private string[] _items; + private int _length; + + public Dictionary Map => (_map ??= new Dictionary()); + + public void Invoke(object writer, WidgetHolder widget, object value, object result, Byte[] bytes, Char[] chars, Options options = (Options)null, string nameHint = (String)null) + { + if (((value is ICollection collection) && (collection.Count == 0))) + { + Use(((Widget)result)); + } + + if ((_items[(collection.Count - 1)] == null)) + { + return; + } + + switch (collection.Count) + { + case ((>= 200) and (< 300)): + return; + } + + switch (value) + { + case string s when (s.Length > 0): + return; + } + + _map = _map ?? (GetMap() ?? new Dictionary()); + _length = (_length + 1); + string format = (nameHint == "W") ? nameHint : "J"; + string converted = TypeFormatters.ToString(bytes); + writer.WriteObjectValue((Widget)result, options); + widget.Value.Equals(widget.Value); + } + + private void Use(Widget widget) { } + private Dictionary GetMap() => null; + } + + public class Widget { } + public class WidgetHolder + { + public Widget Value { get; } + } + public class Options { } + + internal static class TypeFormatters + { + public static string ToString(byte[] value) => null; + public static string Invoke(byte[] value) => TypeFormatters.ToString(value); + } + + internal static class WriterExtensions + { + public static void WriteObjectValue(this object writer, T value, Options options) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Map => _map ??= new Dictionary();")); + Assert.That(generatedText, Does.Contain("if (value is ICollection collection && collection.Count == 0)")); + Assert.That(generatedText, Does.Contain("Use((Widget)result);")); + Assert.That(generatedText, Does.Contain("_items[collection.Count - 1] == null")); + Assert.That(generatedText, Does.Contain("case >= 200 and < 300:")); + Assert.That(generatedText, Does.Contain("byte[] bytes")); + Assert.That(generatedText, Does.Contain("char[] chars")); + Assert.That(generatedText, Does.Contain("Options options = null")); + Assert.That(generatedText, Does.Contain("string nameHint = null")); + Assert.That(generatedText, Does.Contain("case string s when s.Length > 0:")); + Assert.That(generatedText, Does.Contain("_map = _map ?? GetMap() ?? new Dictionary();")); + Assert.That(generatedText, Does.Contain("_length = _length + 1;")); + Assert.That(generatedText, Does.Contain("string format = nameHint == \"W\" ? nameHint : \"J\";")); + Assert.That(generatedText, Does.Contain("TypeFormatters.ToString(bytes);")); + Assert.That(generatedText, Does.Contain("writer.WriteObjectValue((Widget)result, options);")); + Assert.That(generatedText, Does.Contain("public static string Invoke(byte[] value) => ToString(value);")); + Assert.That(generatedText, Does.Contain("widget.Value.Equals(widget.Value);")); + Assert.That(generatedText, Does.Not.Contain("Use(((Widget)result))")); + Assert.That(generatedText, Does.Not.Contain("if ((value is ICollection collection)")); + Assert.That(generatedText, Does.Not.Contain("Byte[]")); + Assert.That(generatedText, Does.Not.Contain("Char[]")); + Assert.That(generatedText, Does.Not.Contain("(Options)null")); + Assert.That(generatedText, Does.Not.Contain("(String)null")); + } + + [Test] + public async Task GetGeneratedFilesAsync_ReducesThisQualificationSafely() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public string Name { get; } + + public void Invoke() + { + this.Create(this.Name); + } + + private void Create(string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Create(Name);")); + Assert.That(generatedText, Does.Not.Contain("this.Create")); + Assert.That(generatedText, Does.Not.Contain("this.Name")); + } + + [Test] + public async Task GetGeneratedFilesAsync_PreservesThisQualificationWhenLocalNameConflicts() + { + var generatedText = await ProcessGeneratedCodeAsync( + """ +// +#nullable disable + +namespace TestNamespace +{ + public class Container + { + public string Name { get; } + + public void Invoke(string Name) + { + this.Create(this.Name); + } + + private void Create(string name) { } + } +} +"""); + + Assert.That(generatedText, Does.Contain("Create(this.Name);")); + } + [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { From 6c536edbc9748180382839525f7358d6c24375bb Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 04:08:35 +0000 Subject: [PATCH 10/16] Add post-processing benchmark Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../perf/PostProcessingBenchmark.cs | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs new file mode 100644 index 00000000000..61cd6d615b2 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.CodeAnalysis; +using Microsoft.TypeSpec.Generator.Primitives; + +namespace Microsoft.TypeSpec.Generator.Perf +{ + public class PostProcessingBenchmark + { + private (string Name, string Content)[] _generatedFiles = []; + + [GlobalSetup] + public void GlobalSetup() + { + InitializeGenerator(); + + var generatedDirectory = FindSampleTypeSpecGeneratedDirectory(); + _generatedFiles = Directory.GetFiles(generatedDirectory, "*.cs", SearchOption.AllDirectories) + .OrderBy(static path => path, StringComparer.Ordinal) + .Select(path => (Name: Path.GetRelativePath(generatedDirectory, path), Content: File.ReadAllText(path))) + .ToArray(); + + if (_generatedFiles.Length == 0) + { + throw new InvalidOperationException($"No generated C# files found under '{generatedDirectory}'."); + } + } + + [Benchmark] + public async Task ProcessSampleTypeSpecGeneratedFiles() + { + GeneratedCodeWorkspace.Initialize(); + var workspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); + + foreach (var file in _generatedFiles) + { + await workspace.AddGeneratedFile(new CodeFile(file.Content, file.Name)); + } + + var totalLength = 0; + await foreach (var file in workspace.GetGeneratedFilesAsync()) + { + totalLength += file.Text.Length; + } + + return totalLength; + } + + private static string FindSampleTypeSpecGeneratedDirectory() + { + const string relativePath = "packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated"; + + var directory = new DirectoryInfo(AppContext.BaseDirectory); + while (directory != null) + { + var generatedDirectory = Path.Combine(directory.FullName, relativePath); + if (Directory.Exists(generatedDirectory)) + { + return generatedDirectory; + } + + directory = directory.Parent; + } + + throw new DirectoryNotFoundException($"Could not find '{relativePath}' from '{AppContext.BaseDirectory}'."); + } + + private static void InitializeGenerator() + { + var outputPath = Path.Combine(AppContext.BaseDirectory, "PostProcessingBenchmark"); + Directory.CreateDirectory(outputPath); + + var generator = new BenchmarkCodeModelGenerator(outputPath); + foreach (var referencePath in GetMetadataReferencePaths()) + { + generator.AddMetadataReference(MetadataReference.CreateFromFile(referencePath)); + } + + CodeModelGenerator.Instance = generator; + } + + private static IEnumerable GetMetadataReferencePaths() + { + HashSet referencePaths = new(StringComparer.OrdinalIgnoreCase); + + if (AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES") is string trustedPlatformAssemblies) + { + foreach (var referencePath in trustedPlatformAssemblies.Split(Path.PathSeparator)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + foreach (var referencePath in Directory.GetFiles(AppContext.BaseDirectory, "*.dll", SearchOption.TopDirectoryOnly)) + { + if (referencePaths.Add(referencePath)) + { + yield return referencePath; + } + } + } + + private sealed class BenchmarkCodeModelGenerator : CodeModelGenerator + { + public BenchmarkCodeModelGenerator(string outputPath) + : base(new GeneratorContext(Configuration.Load(outputPath, "{\"package-name\":\"Sample.TypeSpec\",\"disable-xml-docs\":false}"))) + { + } + } + } +} From 90def0ad10d8e04ae73be8cd4e1ec7ba8ad5408d Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 4 Jun 2026 04:32:37 +0000 Subject: [PATCH 11/16] Support scaled post-processing benchmark corpora Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../perf/PostProcessingBenchmark.cs | 91 ++++++++++++++++++- 1 file changed, 86 insertions(+), 5 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs index 61cd6d615b2..f96a0d5053b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Text.RegularExpressions; using System.Threading.Tasks; using BenchmarkDotNet.Attributes; using Microsoft.CodeAnalysis; @@ -14,6 +15,14 @@ namespace Microsoft.TypeSpec.Generator.Perf { public class PostProcessingBenchmark { + private const string GeneratedDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_GENERATED_DIR"; + private static readonly Regex NamespaceDeclarationRegex = new( + @"\bnamespace\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)", + RegexOptions.Compiled); + + [Params(1, 5)] + public int CorpusMultiplier { get; set; } + private (string Name, string Content)[] _generatedFiles = []; [GlobalSetup] @@ -21,16 +30,18 @@ public void GlobalSetup() { InitializeGenerator(); - var generatedDirectory = FindSampleTypeSpecGeneratedDirectory(); - _generatedFiles = Directory.GetFiles(generatedDirectory, "*.cs", SearchOption.AllDirectories) + var generatedDirectory = FindGeneratedDirectory(); + var sourceFiles = Directory.GetFiles(generatedDirectory, "*.cs", SearchOption.AllDirectories) .OrderBy(static path => path, StringComparer.Ordinal) - .Select(path => (Name: Path.GetRelativePath(generatedDirectory, path), Content: File.ReadAllText(path))) .ToArray(); - if (_generatedFiles.Length == 0) + if (sourceFiles.Length == 0) { throw new InvalidOperationException($"No generated C# files found under '{generatedDirectory}'."); } + + var declaredNamespaces = GetDeclaredNamespaces(sourceFiles); + _generatedFiles = BuildCorpus(generatedDirectory, sourceFiles, declaredNamespaces); } [Benchmark] @@ -53,8 +64,78 @@ public async Task ProcessSampleTypeSpecGeneratedFiles() return totalLength; } - private static string FindSampleTypeSpecGeneratedDirectory() + private (string Name, string Content)[] BuildCorpus(string generatedDirectory, string[] sourceFiles, IReadOnlyList declaredNamespaces) + { + var generatedFiles = new List<(string Name, string Content)>(sourceFiles.Length * CorpusMultiplier); + for (var i = 0; i < CorpusMultiplier; i++) + { + var namespaceSuffix = CorpusMultiplier == 1 ? string.Empty : $".BenchmarkCopy{i}"; + var folderPrefix = CorpusMultiplier == 1 ? string.Empty : $"BenchmarkCopy{i}"; + foreach (var path in sourceFiles) + { + var relativePath = Path.GetRelativePath(generatedDirectory, path); + var content = File.ReadAllText(path); + if (CorpusMultiplier > 1) + { + content = MakeNamespacesUnique(content, declaredNamespaces, namespaceSuffix); + } + + generatedFiles.Add((Path.Combine(folderPrefix, relativePath), content)); + } + } + + return generatedFiles.ToArray(); + } + + private static IReadOnlyList GetDeclaredNamespaces(string[] sourceFiles) + { + var declaredNamespaces = sourceFiles + .SelectMany(static path => NamespaceDeclarationRegex.Matches(File.ReadAllText(path))) + .Select(static match => match.Groups[1].Value) + .Distinct(StringComparer.Ordinal) + .ToArray(); + + return declaredNamespaces + .Where(ns => !declaredNamespaces.Any(candidate => + !string.Equals(ns, candidate, StringComparison.Ordinal) && + ns.StartsWith(candidate + ".", StringComparison.Ordinal))) + .OrderByDescending(static ns => ns.Length) + .ToArray(); + } + + private static string MakeNamespacesUnique(string content, IReadOnlyList declaredNamespaces, string namespaceSuffix) + { + foreach (var declaredNamespace in declaredNamespaces) + { + var escapedNamespace = Regex.Escape(declaredNamespace); + content = content.Replace($"global::{declaredNamespace}.", $"global::{declaredNamespace}{namespaceSuffix}.", StringComparison.Ordinal); + content = Regex.Replace( + content, + $@"(? Date: Thu, 4 Jun 2026 05:04:40 +0000 Subject: [PATCH 12/16] Reduce manual post-processing passes Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../PostProcessing/GeneratedCodeWorkspace.cs | 300 ++++++++++++++---- .../test/GeneratedCodeWorkspaceTests.cs | 3 +- 2 files changed, 234 insertions(+), 69 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 9aa35274248..38ac362d741 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -162,30 +162,8 @@ private async Task ProcessDocument(Document document, MemberRemoverRew document = reducedDocument; } - for (int i = 0; i < 8; i++) - { - var reducedDocument = await ReduceParenthesesAsync(document); - if (ReferenceEquals(reducedDocument, document)) - { - break; - } - - document = reducedDocument; - } - - document = await ReduceGenericMethodTypeArgumentsAsync(document); - document = await ReduceThisQualificationAsync(document); - for (int i = 0; i < 4; i++) - { - var reducedDocument = await ReducePredefinedTypeNamesAsync(document); - if (ReferenceEquals(reducedDocument, document)) - { - break; - } - - document = reducedDocument; - } - + document = await ReduceSemanticOnlyAsync(document); + document = await ReduceSyntaxOnlyAsync(document); document = await ReduceDocumentationQualifiedNamesAsync(document); // Reformat if any custom rewriters have been applied @@ -205,45 +183,6 @@ private static async Task ReduceQualifiedNamesAsync(Document document) return document; } - var globalAliases = new Dictionary(); - foreach (var aliasName in root.DescendantNodes().OfType()) - { - if (aliasName.Alias.Identifier.ValueText != "global" || - aliasName.Ancestors().Any(static ancestor => - ancestor is AttributeSyntax) || - IsInMemberAccessExpressionChain(aliasName) || - IsInUnsupportedQualifiedNameContext(aliasName)) - { - continue; - } - - var originalSymbol = GetSymbol(semanticModel, aliasName); - if (originalSymbol == null) - { - continue; - } - - var replacement = aliasName.Name.WithTriviaFrom(aliasName); - if (SpeculativelyBindsToSameSymbol(semanticModel, aliasName, replacement, originalSymbol)) - { - globalAliases.Add(aliasName, replacement); - } - } - - if (globalAliases.Count > 0) - { - root = root.ReplaceNodes( - globalAliases.Keys, - (original, rewritten) => globalAliases[original].WithTriviaFrom(rewritten)); - document = document.WithSyntaxRoot(root); - semanticModel = await document.GetSemanticModelAsync(); - root = await document.GetSyntaxRootAsync(); - if (root == null || semanticModel == null) - { - return document; - } - } - var safeNameReplacements = new Dictionary(); foreach (var name in root.DescendantNodes().OfType()) { @@ -294,7 +233,8 @@ name.Parent is QualifiedNameSyntax || } var rewrittenRoot = root.ReplaceNodes( - safeNameReplacements.Keys.Concat(safeMemberAccessReplacements.Keys), + safeNameReplacements.Keys + .Concat(safeMemberAccessReplacements.Keys), (original, rewritten) => original switch { NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), @@ -304,6 +244,153 @@ name.Parent is QualifiedNameSyntax || return document.WithSyntaxRoot(rewrittenRoot); } + private static async Task ReduceSemanticOnlyAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + var semanticModel = await document.GetSemanticModelAsync(); + if (root == null || semanticModel == null) + { + return document; + } + + var memberAccessReplacements = new Dictionary(); + foreach (var memberAccess in root.DescendantNodes().OfType()) + { + if (!IsInUnsupportedQualifiedNameContext(memberAccess) && + TryGetThisQualificationReplacement(semanticModel, memberAccess, out var replacement)) + { + memberAccessReplacements.Add(memberAccess, replacement); + } + } + + var genericNameReplacements = new Dictionary(); + foreach (var genericName in root.DescendantNodes().OfType()) + { + if (TryGetGenericMethodTypeArgumentsReplacement(semanticModel, genericName, out var replacement)) + { + genericNameReplacements.Add(genericName, replacement); + } + } + + if (memberAccessReplacements.Count == 0 && genericNameReplacements.Count == 0) + { + return document; + } + + var rewrittenRoot = root.ReplaceNodes( + memberAccessReplacements.Keys.Concat(genericNameReplacements.Keys), + (original, rewritten) => original switch + { + GenericNameSyntax genericName => genericNameReplacements[genericName].WithTriviaFrom(rewritten), + MemberAccessExpressionSyntax memberAccess => memberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), + _ => rewritten + }); + return document.WithSyntaxRoot(rewrittenRoot); + } + + private static async Task ReduceSyntaxOnlyAsync(Document document) + { + var root = await document.GetSyntaxRootAsync(); + if (root == null) + { + return document; + } + + var rewriter = new SyntaxOnlyReducer(); + var rewrittenRoot = rewriter.Visit(root); + return rewriter.Changed && rewrittenRoot != null + ? document.WithSyntaxRoot(rewrittenRoot) + : document; + } + + private sealed class SyntaxOnlyReducer : CSharpSyntaxRewriter + { + public bool Changed { get; private set; } + + public override SyntaxNode? VisitParenthesizedExpression(ParenthesizedExpressionSyntax node) + { + var rewritten = (ParenthesizedExpressionSyntax)base.VisitParenthesizedExpression(node)!; + if (CanRemoveParentheses(node)) + { + Changed = true; + return rewritten.Expression.WithTriviaFrom(rewritten); + } + + return rewritten; + } + + public override SyntaxNode? VisitParenthesizedPattern(ParenthesizedPatternSyntax node) + { + var rewritten = (ParenthesizedPatternSyntax)base.VisitParenthesizedPattern(node)!; + Changed = true; + return rewritten.Pattern.WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) + { + var rewritten = (IdentifierNameSyntax)base.VisitIdentifierName(node)!; + if (rewritten.Identifier.ValueText is not ("Byte" or "Char" or "String") || + node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) + { + return rewritten; + } + + Changed = true; + return GetPredefinedType(rewritten.Identifier.ValueText).WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node) + { + var rewritten = (QualifiedNameSyntax)base.VisitQualifiedName(node)!; + if (rewritten.Left is not IdentifierNameSyntax { Identifier.ValueText: "System" } || + rewritten.Right.Identifier.ValueText is not ("Byte" or "Char" or "String") || + node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) + { + return rewritten; + } + + Changed = true; + return GetPredefinedType(rewritten.Right.Identifier.ValueText).WithTriviaFrom(rewritten); + } + + public override SyntaxNode? VisitCastExpression(CastExpressionSyntax node) + { + var rewritten = (CastExpressionSyntax)base.VisitCastExpression(node)!; + if (rewritten.Expression is LiteralExpressionSyntax literalExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && + (rewritten.Parent is EqualsValueClauseSyntax || node.Parent is EqualsValueClauseSyntax)) + { + Changed = true; + return rewritten.Expression.WithTriviaFrom(rewritten); + } + + return rewritten; + } + + public override SyntaxNode? VisitEqualsValueClause(EqualsValueClauseSyntax node) + { + var rewritten = (EqualsValueClauseSyntax)base.VisitEqualsValueClause(node)!; + if (rewritten.Value is CastExpressionSyntax { Expression: LiteralExpressionSyntax literalExpression } castExpression && + (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || + literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression))) + { + Changed = true; + return rewritten.WithValue(castExpression.Expression.WithTriviaFrom(castExpression)); + } + + return rewritten; + } + + private static PredefinedTypeSyntax GetPredefinedType(string typeName) => typeName switch + { + "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)), + "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)), + "String" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)), + _ => throw new InvalidOperationException($"Unexpected predefined type name: {typeName}") + }; + } + private static ISymbol? GetSymbol(SemanticModel semanticModel, NameSyntax name) => semanticModel.GetSymbolInfo(name).Symbol ?? semanticModel.GetTypeInfo(name).Type; @@ -461,6 +548,35 @@ memberAccess.Parent is not InvocationExpressionSyntax invocation || return document.WithSyntaxRoot(rewrittenRoot); } + private static bool TryGetGenericMethodTypeArgumentsReplacement( + SemanticModel semanticModel, + GenericNameSyntax genericName, + out IdentifierNameSyntax replacement) + { + replacement = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); + if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || + memberAccess.Name != genericName || + memberAccess.Parent is not InvocationExpressionSyntax invocation || + invocation.Expression != memberAccess) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; + if (originalSymbol == null) + { + return false; + } + + var candidateInvocation = invocation.WithExpression(memberAccess.WithName(replacement)); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + invocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + return speculativeSymbol != null && + SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); + } + private static async Task ReduceThisQualificationAsync(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -528,6 +644,58 @@ memberAccess.Parent is InvocationExpressionSyntax parentInvocation && return document.WithSyntaxRoot(rewrittenRoot); } + private static bool TryGetThisQualificationReplacement( + SemanticModel semanticModel, + MemberAccessExpressionSyntax memberAccess, + out ExpressionSyntax replacement) + { + replacement = memberAccess; + if (memberAccess.Expression is not ThisExpressionSyntax) + { + return false; + } + + var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; + var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess + ? semanticModel.GetSymbolInfo(invocation).Symbol + : null; + if (originalSymbol == null) + { + originalSymbol = originalInvocationSymbol; + if (originalSymbol == null) + { + return false; + } + } + + var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); + var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + memberAccess.SpanStart, + candidate, + SpeculativeBindingOption.BindAsExpression).Symbol; + if (speculativeSymbol == null && + memberAccess.Parent is InvocationExpressionSyntax parentInvocation && + parentInvocation.Expression == memberAccess && + originalInvocationSymbol != null) + { + var candidateInvocation = parentInvocation.WithExpression(candidate); + speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( + parentInvocation.SpanStart, + candidateInvocation, + SpeculativeBindingOption.BindAsExpression).Symbol; + originalSymbol = originalInvocationSymbol; + } + + if (speculativeSymbol == null || + !SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + { + return false; + } + + replacement = candidate; + return true; + } + private static async Task ReducePredefinedTypeNamesAsync(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -995,10 +1163,6 @@ private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax express expression.Ancestors().Any(static ancestor => ancestor is CrefSyntax); - private static bool IsInMemberAccessExpressionChain(NameSyntax name) => - name.Ancestors().Any(static ancestor => ancestor is MemberAccessExpressionSyntax) && - !name.Ancestors().Any(static ancestor => ancestor is TypeSyntax); - private static bool ContainsSimplifierAnnotations(SyntaxNode root) => root.HasAnnotation(Simplifier.Annotation) || root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index 55c0970cbdc..d1aec2a8486 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -468,7 +468,7 @@ public class Container public Dictionary Map => (_map ??= new Dictionary()); - public void Invoke(object writer, WidgetHolder widget, object value, object result, Byte[] bytes, Char[] chars, Options options = (Options)null, string nameHint = (String)null) + public void Invoke(object writer, WidgetHolder widget, object value, object result, Byte[] bytes, Char[] chars, Options options = (Options)null, string nameHint = (String)null, global::System.String title = (global::System.String)null) { if (((value is ICollection collection) && (collection.Count == 0))) { @@ -533,6 +533,7 @@ public static void WriteObjectValue(this object writer, T value, Options opti Assert.That(generatedText, Does.Contain("char[] chars")); Assert.That(generatedText, Does.Contain("Options options = null")); Assert.That(generatedText, Does.Contain("string nameHint = null")); + Assert.That(generatedText, Does.Contain("string title = null")); Assert.That(generatedText, Does.Contain("case string s when s.Length > 0:")); Assert.That(generatedText, Does.Contain("_map = _map ?? GetMap() ?? new Dictionary();")); Assert.That(generatedText, Does.Contain("_length = _length + 1;")); From 8f6da6a99728897853c658748141afd71a6028c7 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 11 Jun 2026 14:50:14 +0000 Subject: [PATCH 13/16] Profile C# post-processing benchmarks --- .../perf/FullGenerationBenchmark.cs | 157 +++ .../perf/PostProcessingBenchmark.cs | 60 +- .../src/CSharpGen.cs | 162 ++- .../PostProcessing/GeneratedCodeWorkspace.cs | 1086 +---------------- ...ratedCodeWorkspacePostProcessingProfile.cs | 64 + .../src/PostProcessing/PostProcessor.cs | 186 ++- .../test/GeneratedCodeWorkspaceTests.cs | 530 -------- 7 files changed, 592 insertions(+), 1653 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs new file mode 100644 index 00000000000..66553e08c38 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; + +namespace Microsoft.TypeSpec.Generator.Perf +{ + public class FullGenerationBenchmark + { + private const string ProfileEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_STEPS"; + private const string ProfileOutputDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_DIR"; + + private bool _profileSteps; + + [GlobalSetup] + public void GlobalSetup() + { + _profileSteps = string.Equals( + Environment.GetEnvironmentVariable(ProfileEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + } + + [Benchmark] + public async Task GenerateSampleTypeSpecProject() + { + var postProcessingProfile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + var generationProfile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + GeneratedCodeWorkspace.PostProcessingProfile = postProcessingProfile; + CSharpGen.GenerationProfile = generationProfile; + + var benchmarkDirectory = CreateBenchmarkInputDirectory(); + var stopwatch = Stopwatch.StartNew(); + try + { + CodeModelGenerator.Instance = new BenchmarkCodeModelGenerator(benchmarkDirectory); + CodeModelGenerator.Instance.Configure(); + + var csharpGen = new CSharpGen(); + await csharpGen.ExecuteAsync(); + + return Directory.GetFiles(benchmarkDirectory, "*", SearchOption.AllDirectories) + .Where(static path => !path.EndsWith("tspCodeModel.json", StringComparison.Ordinal) && + !path.EndsWith("Configuration.json", StringComparison.Ordinal)) + .Sum(static path => (int)new FileInfo(path).Length); + } + finally + { + stopwatch.Stop(); + if (generationProfile != null) + { + WriteProfile( + generationProfile, + $"full-generation-profile-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Full generation invocation elapsed ms: {stopwatch.Elapsed.TotalMilliseconds:F3}{Environment.NewLine}" + + $"Input directory: {benchmarkDirectory}{Environment.NewLine}"); + } + + if (postProcessingProfile != null) + { + WriteProfile( + postProcessingProfile, + $"full-generation-post-processing-profile-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Full generation post-processing profile{Environment.NewLine}" + + $"Input directory: {benchmarkDirectory}{Environment.NewLine}"); + } + + CSharpGen.GenerationProfile = null; + GeneratedCodeWorkspace.PostProcessingProfile = null; + TryDeleteDirectory(benchmarkDirectory); + } + } + + private static void WriteProfile(GeneratedCodeWorkspacePostProcessingProfile profile, string fileName, string header) + { + var profileDirectory = GetProfileOutputDirectory(); + Directory.CreateDirectory(profileDirectory); + File.WriteAllText(Path.Combine(profileDirectory, fileName), header + profile.GetSummary()); + } + + private static string CreateBenchmarkInputDirectory() + { + var sourceDirectory = FindFullGenerationInputDirectory(); + var benchmarkDirectory = Path.Combine(Path.GetTempPath(), "typespec-full-generation-benchmark", Guid.NewGuid().ToString("N")); + CopyDirectory(sourceDirectory, benchmarkDirectory); + return benchmarkDirectory; + } + + private static string FindFullGenerationInputDirectory() + { + const string relativePath = "packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec"; + + var directory = new DirectoryInfo(AppContext.BaseDirectory); + while (directory != null) + { + var inputDirectory = Path.Combine(directory.FullName, relativePath); + if (File.Exists(Path.Combine(inputDirectory, "tspCodeModel.json")) && + File.Exists(Path.Combine(inputDirectory, "Configuration.json"))) + { + return inputDirectory; + } + + directory = directory.Parent; + } + + throw new DirectoryNotFoundException($"Could not find '{relativePath}' from '{AppContext.BaseDirectory}'."); + } + + private static void CopyDirectory(string sourceDirectory, string destinationDirectory) + { + Directory.CreateDirectory(destinationDirectory); + foreach (var sourceFile in Directory.GetFiles(sourceDirectory, "*", SearchOption.AllDirectories)) + { + var relativePath = Path.GetRelativePath(sourceDirectory, sourceFile); + var destinationFile = Path.Combine(destinationDirectory, relativePath); + Directory.CreateDirectory(Path.GetDirectoryName(destinationFile)!); + File.Copy(sourceFile, destinationFile, overwrite: true); + } + } + + private static void TryDeleteDirectory(string directory) + { + try + { + if (Directory.Exists(directory)) + { + Directory.Delete(directory, recursive: true); + } + } + catch + { + // Best-effort cleanup for benchmark temp output. + } + } + + private static string GetProfileOutputDirectory() + { + var configuredPath = Environment.GetEnvironmentVariable(ProfileOutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(configuredPath) + ? Path.Combine(Path.GetTempPath(), "typespec-post-processing-profiles") + : Path.GetFullPath(configuredPath); + } + + private sealed class BenchmarkCodeModelGenerator : CodeModelGenerator + { + public BenchmarkCodeModelGenerator(string outputPath) + : base(new GeneratorContext(Configuration.Load(outputPath))) + { + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs index f96a0d5053b..24b5394e2cd 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/PostProcessingBenchmark.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using System.Text.RegularExpressions; @@ -16,6 +17,8 @@ namespace Microsoft.TypeSpec.Generator.Perf public class PostProcessingBenchmark { private const string GeneratedDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_GENERATED_DIR"; + private const string ProfileEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_STEPS"; + private const string ProfileOutputDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_DIR"; private static readonly Regex NamespaceDeclarationRegex = new( @"\bnamespace\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)", RegexOptions.Compiled); @@ -24,6 +27,7 @@ public class PostProcessingBenchmark public int CorpusMultiplier { get; set; } private (string Name, string Content)[] _generatedFiles = []; + private bool _profileSteps; [GlobalSetup] public void GlobalSetup() @@ -42,26 +46,68 @@ public void GlobalSetup() var declaredNamespaces = GetDeclaredNamespaces(sourceFiles); _generatedFiles = BuildCorpus(generatedDirectory, sourceFiles, declaredNamespaces); + _profileSteps = string.Equals( + Environment.GetEnvironmentVariable(ProfileEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); } [Benchmark] public async Task ProcessSampleTypeSpecGeneratedFiles() { + var profile = _profileSteps ? new GeneratedCodeWorkspacePostProcessingProfile() : null; + GeneratedCodeWorkspace.PostProcessingProfile = profile; + + var stopwatch = Stopwatch.StartNew(); GeneratedCodeWorkspace.Initialize(); var workspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); - foreach (var file in _generatedFiles) + try { - await workspace.AddGeneratedFile(new CodeFile(file.Content, file.Name)); - } + foreach (var file in _generatedFiles) + { + await workspace.AddGeneratedFile(new CodeFile(file.Content, file.Name)); + } - var totalLength = 0; - await foreach (var file in workspace.GetGeneratedFilesAsync()) + var totalLength = 0; + await foreach (var file in workspace.GetGeneratedFilesAsync()) + { + totalLength += file.Text.Length; + } + + return totalLength; + } + finally { - totalLength += file.Text.Length; + stopwatch.Stop(); + if (profile != null) + { + WriteProfile( + profile, + $"post-processing-profile-RoslynSimplifier-x{CorpusMultiplier}-{DateTime.UtcNow:yyyyMMddHHmmssfff}.csv", + $"Reduction strategy: RoslynSimplifier{Environment.NewLine}" + + $"Corpus multiplier: {CorpusMultiplier}{Environment.NewLine}" + + $"File count: {_generatedFiles.Length}{Environment.NewLine}" + + $"Benchmark invocation elapsed ms: {stopwatch.Elapsed.TotalMilliseconds:F3}{Environment.NewLine}"); + } + + GeneratedCodeWorkspace.PostProcessingProfile = null; } + } - return totalLength; + private static void WriteProfile(GeneratedCodeWorkspacePostProcessingProfile profile, string fileName, string header) + { + var profileDirectory = GetProfileOutputDirectory(); + Directory.CreateDirectory(profileDirectory); + File.WriteAllText(Path.Combine(profileDirectory, fileName), header + profile.GetSummary()); + } + + private static string GetProfileOutputDirectory() + { + var configuredPath = Environment.GetEnvironmentVariable(ProfileOutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(configuredPath) + ? Path.Combine(Path.GetTempPath(), "typespec-post-processing-profiles") + : Path.GetFullPath(configuredPath); } private (string Name, string Content)[] BuildCorpus(string generatedDirectory, string[] sourceFiles, IReadOnlyList declaredNamespaces) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index 9948fcff594..238b16b9766 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using System.Threading.Tasks; @@ -20,6 +21,8 @@ internal sealed class CSharpGen private static readonly string[] _filesToKeep = [ConfigurationFileName, CodeModelFileName]; + internal static GeneratedCodeWorkspacePostProcessingProfile? GenerationProfile { get; set; } + /// /// Executes the generator task with the instance. /// @@ -33,19 +36,21 @@ public async Task ExecuteAsync() // Resolve PackageReference items from the .csproj so custom code referencing // external NuGet types (e.g., Azure.Storage.Common) compiles correctly. - await GeneratedCodeWorkspace.AddPackageReferencesFromProject(); + await MeasureGenerationStepAsync("Generation.AddPackageReferencesFromProject", GeneratedCodeWorkspace.AddPackageReferencesFromProject); // Pre-walk the input library and resolve any external types that point at NuGet packages. // This populates ExternalTypeReferenceResolver's cache and registers each resolved assembly // as an additional metadata reference *before* the generated/custom code workspaces are // constructed, so their cached Roslyn projects pick the references up. - await ExternalTypeReferenceResolver.ResolveAllAsync(); + await MeasureGenerationStepAsync("Generation.ResolveExternalTypeReferences", ExternalTypeReferenceResolver.ResolveAllAsync); // Initialize the workspace project AFTER all metadata references have been added so the // eagerly-cached project sees them. GeneratedCodeWorkspace.Initialize(); - GeneratedCodeWorkspace customCodeWorkspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: true); + GeneratedCodeWorkspace customCodeWorkspace = await MeasureGenerationStepAsync( + "Generation.CreateCustomCodeWorkspace", + () => GeneratedCodeWorkspace.Create(isCustomCodeProject: true)); // The generated attributes need to be added into the workspace before loading the custom code. Otherwise, // Roslyn doesn't load the attributes completely and we are unable to get the attribute arguments. @@ -55,88 +60,167 @@ public async Task ExecuteAsync() generateAttributeTasks.Add(customCodeWorkspace.AddInMemoryFile(attributeProvider)); } - await Task.WhenAll(generateAttributeTasks); + await MeasureGenerationStepAsync("Generation.AddCustomizationAttributeProviders", () => Task.WhenAll(generateAttributeTasks)); - CodeModelGenerator.Instance.SourceInputModel = new SourceInputModel( - await customCodeWorkspace.GetCompilationAsync(), - await GeneratedCodeWorkspace.LoadBaselineContract()); + CodeModelGenerator.Instance.SourceInputModel = await MeasureGenerationStepAsync( + "Generation.CreateSourceInputModel", + async () => new SourceInputModel( + await customCodeWorkspace.GetCompilationAsync(), + await GeneratedCodeWorkspace.LoadBaselineContract())); - GeneratedCodeWorkspace generatedCodeWorkspace = await GeneratedCodeWorkspace.Create(isCustomCodeProject: false); + GeneratedCodeWorkspace generatedCodeWorkspace = await MeasureGenerationStepAsync( + "Generation.CreateGeneratedCodeWorkspace", + () => GeneratedCodeWorkspace.Create(isCustomCodeProject: false)); - var output = CodeModelGenerator.Instance.OutputLibrary; + var output = MeasureGenerationStep("Generation.GetOutputLibrary", () => CodeModelGenerator.Instance.OutputLibrary); Directory.CreateDirectory(Path.Combine(generatedSourceOutputPath, "Models")); List generateFilesTasks = new(); // Build all TypeProviders - foreach (var type in output.TypeProviders) + MeasureGenerationStep("Generation.BuildTypeProviders", () => { - type.EnsureBuilt(); - } + foreach (var type in output.TypeProviders) + { + type.EnsureBuilt(); + } + }); LoggingHelpers.LogElapsedTime("All generated type providers built"); // visit the entire library before generating files - foreach (var visitor in CodeModelGenerator.Instance.Visitors) + MeasureGenerationStep("Generation.ApplyVisitors", () => { - visitor.VisitLibrary(output); - } + foreach (var visitor in CodeModelGenerator.Instance.Visitors) + { + visitor.VisitLibrary(output); + } + }); - FilterAllCustomizedMembers(output); + MeasureGenerationStep("Generation.FilterCustomizedMembers", () => FilterAllCustomizedMembers(output)); LoggingHelpers.LogElapsedTime("All visitors have been applied"); - foreach (var outputType in output.TypeProviders) + MeasureGenerationStep("Generation.WriteTypeProviders", () => { - // Ensure back-compatibility processing is done after all visitors have run - outputType.ProcessTypeForBackCompatibility(); - - var writer = CodeModelGenerator.Instance.GetWriter(outputType); - generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); - - foreach (var serialization in outputType.SerializationProviders) + foreach (var outputType in output.TypeProviders) { - writer = CodeModelGenerator.Instance.GetWriter(serialization); + // Ensure back-compatibility processing is done after all visitors have run + outputType.ProcessTypeForBackCompatibility(); + + var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + + foreach (var serialization in outputType.SerializationProviders) + { + writer = CodeModelGenerator.Instance.GetWriter(serialization); + generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + } } - } + }); // Add all the generated files to the workspace - await Task.WhenAll(generateFilesTasks); + await MeasureGenerationStepAsync("Generation.AddGeneratedFilesToWorkspace", () => Task.WhenAll(generateFilesTasks)); LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files - DeleteDirectory(generatedSourceOutputPath, _filesToKeep); + MeasureGenerationStep("Generation.DeleteOldGeneratedFiles", () => DeleteDirectory(generatedSourceOutputPath, _filesToKeep)); LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); - await generatedCodeWorkspace.PostProcessAsync(); + await MeasureGenerationStepAsync("Generation.PostProcessAsync", generatedCodeWorkspace.PostProcessAsync); // Write the generated files to the output directory - await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) + await MeasureGenerationStepAsync("Generation.WriteGeneratedFilesToDisk", async () => { - if (string.IsNullOrEmpty(file.Text)) + await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) { - continue; + if (string.IsNullOrEmpty(file.Text)) + { + continue; + } + var filename = Path.Combine(outputPath, file.Name); + CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); + Directory.CreateDirectory(Path.GetDirectoryName(filename)!); + await File.WriteAllTextAsync(filename, file.Text); } - var filename = Path.Combine(outputPath, file.Name); - CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); - Directory.CreateDirectory(Path.GetDirectoryName(filename)!); - await File.WriteAllTextAsync(filename, file.Text); - } + }); // Write additional output files (e.g. configuration schemas, .targets files) - await CodeModelGenerator.Instance.WriteAdditionalFiles(outputPath); + await MeasureGenerationStepAsync("Generation.WriteAdditionalFiles", () => CodeModelGenerator.Instance.WriteAdditionalFiles(outputPath)); // Write project scaffolding files (after additional files so schema existence can be checked) if (CodeModelGenerator.Instance.IsNewProject) { - await CodeModelGenerator.Instance.TypeFactory.CreateNewProjectScaffolding().Execute(); + await MeasureGenerationStepAsync( + "Generation.WriteProjectScaffolding", + () => CodeModelGenerator.Instance.TypeFactory.CreateNewProjectScaffolding().Execute()); } LoggingHelpers.LogElapsedTime("All files have been written to disk"); } + private static void MeasureGenerationStep(string stepName, Action action) + { + MeasureGenerationStep( + stepName, + () => + { + action(); + return 0; + }); + } + + private static T MeasureGenerationStep(string stepName, Func action) + { + var profile = GenerationProfile; + if (profile == null) + { + return action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + + private static Task MeasureGenerationStepAsync(string stepName, Func action) => MeasureGenerationStepAsync( + stepName, + async () => + { + await action(); + return 0; + }); + + private static async Task MeasureGenerationStepAsync(string stepName, Func> action) + { + var profile = GenerationProfile; + if (profile == null) + { + return await action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return await action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + internal static void FilterAllCustomizedMembers(OutputLibrary output) { foreach (var typeProvider in output.TypeProviders) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 38ac362d741..bf4b08b31b5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -6,7 +6,6 @@ using System.Diagnostics; using System.IO; using System.Linq; -using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Build.Construction; @@ -16,7 +15,6 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Formatting; using Microsoft.CodeAnalysis.Simplification; -using Microsoft.CodeAnalysis.Text; using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Providers; using Microsoft.TypeSpec.Generator.Utilities; @@ -40,6 +38,8 @@ internal class GeneratedCodeWorkspace private static readonly Lazy _metadataReferenceResolver = new(() => new WorkspaceMetadataReferenceResolver()); private static Task? _cachedProject; + internal static GeneratedCodeWorkspacePostProcessingProfile? PostProcessingProfile { get; set; } + private static readonly string[] _generatedFolders = [GeneratedFolder]; private static readonly string[] _sharedFolders = [SharedFolder]; @@ -105,7 +105,7 @@ internal static SyntaxTree GetTree(TypeProvider provider) public async IAsyncEnumerable<(string Name, string Text)> GetGeneratedFilesAsync() { - List> documents = new List>(); + List docs = new List(); var memberRemover = new MemberRemoverRewriter(); foreach (Document document in _project.Documents) { @@ -114,9 +114,12 @@ internal static SyntaxTree GetTree(TypeProvider provider) continue; } - documents.Add(ProcessDocument(document, memberRemover)); + docs.Add(document); } - var docs = await Task.WhenAll(documents); + + docs = PostProcessingProfile == null + ? [.. await Task.WhenAll(docs.Select(document => ProcessDocument(document, memberRemover)))] + : await ProcessDocumentsSequentiallyAsync(docs, memberRemover); LoggingHelpers.LogElapsedTime("Roslyn post processing complete"); @@ -132,1073 +135,100 @@ internal static SyntaxTree GetTree(TypeProvider provider) } } - private async Task ProcessDocument(Document document, MemberRemoverRewriter memberRemover) + private async Task> ProcessDocumentsSequentiallyAsync(List documents, MemberRemoverRewriter memberRemover) { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); - - if (semanticModel == null || root == null) - { - return document; - } - - root = memberRemover.Visit(root); - - foreach (var rewriter in CodeModelGenerator.Instance.Rewriters) + List processedDocuments = new(documents.Count); + foreach (var document in documents) { - rewriter.SemanticModel = semanticModel; - root = rewriter.Visit(root); + processedDocuments.Add(await ProcessDocument(document, memberRemover)); } - document = document.WithSyntaxRoot(root); - for (int i = 0; i < 8; i++) - { - var reducedDocument = await ReduceQualifiedNamesAsync(document); - if (ReferenceEquals(reducedDocument, document)) - { - break; - } - - document = reducedDocument; - } - - document = await ReduceSemanticOnlyAsync(document); - document = await ReduceSyntaxOnlyAsync(document); - document = await ReduceDocumentationQualifiedNamesAsync(document); - - // Reformat if any custom rewriters have been applied - if (CodeModelGenerator.Instance.Rewriters.Count > 0) - { - document = await Formatter.FormatAsync(document); - } - return document; + return processedDocuments; } - private static async Task ReduceQualifiedNamesAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); - if (root == null || semanticModel == null) - { - return document; - } - - var safeNameReplacements = new Dictionary(); - foreach (var name in root.DescendantNodes().OfType()) - { - if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax || - name.Parent is QualifiedNameSyntax || - IsInUnsupportedQualifiedNameContext(name)) - { - continue; - } - - var originalSymbol = GetSymbol(semanticModel, name); - if (originalSymbol == null) - { - continue; - } - - if (TryGetNameReplacement(semanticModel, name, originalSymbol, out var replacement)) - { - safeNameReplacements.Add(name, replacement); - } - } - - foreach (var attribute in root.DescendantNodes().OfType()) - { - if (TryGetAttributeNameReplacement(semanticModel, attribute, out var replacement)) - { - safeNameReplacements[attribute.Name] = replacement; - } - } - - var safeMemberAccessReplacements = new Dictionary(); - foreach (var memberAccess in root.DescendantNodes().OfType()) - { - if (IsInUnsupportedQualifiedNameContext(memberAccess)) - { - continue; - } - - if (TryGetMemberAccessReplacement(semanticModel, memberAccess, out var replacement)) - { - safeMemberAccessReplacements.Add(memberAccess, replacement); - } - } - - if (safeNameReplacements.Count == 0 && safeMemberAccessReplacements.Count == 0) - { - return document; - } - - var rewrittenRoot = root.ReplaceNodes( - safeNameReplacements.Keys - .Concat(safeMemberAccessReplacements.Keys), - (original, rewritten) => original switch - { - NameSyntax name => safeNameReplacements[name].WithTriviaFrom(rewritten), - MemberAccessExpressionSyntax memberAccess => safeMemberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), - _ => rewritten - }); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static async Task ReduceSemanticOnlyAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); - if (root == null || semanticModel == null) - { - return document; - } - - var memberAccessReplacements = new Dictionary(); - foreach (var memberAccess in root.DescendantNodes().OfType()) - { - if (!IsInUnsupportedQualifiedNameContext(memberAccess) && - TryGetThisQualificationReplacement(semanticModel, memberAccess, out var replacement)) - { - memberAccessReplacements.Add(memberAccess, replacement); - } - } - - var genericNameReplacements = new Dictionary(); - foreach (var genericName in root.DescendantNodes().OfType()) - { - if (TryGetGenericMethodTypeArgumentsReplacement(semanticModel, genericName, out var replacement)) - { - genericNameReplacements.Add(genericName, replacement); - } - } - - if (memberAccessReplacements.Count == 0 && genericNameReplacements.Count == 0) - { - return document; - } - - var rewrittenRoot = root.ReplaceNodes( - memberAccessReplacements.Keys.Concat(genericNameReplacements.Keys), - (original, rewritten) => original switch - { - GenericNameSyntax genericName => genericNameReplacements[genericName].WithTriviaFrom(rewritten), - MemberAccessExpressionSyntax memberAccess => memberAccessReplacements[memberAccess].WithTriviaFrom(rewritten), - _ => rewritten - }); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static async Task ReduceSyntaxOnlyAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - if (root == null) - { - return document; - } - - var rewriter = new SyntaxOnlyReducer(); - var rewrittenRoot = rewriter.Visit(root); - return rewriter.Changed && rewrittenRoot != null - ? document.WithSyntaxRoot(rewrittenRoot) - : document; - } - - private sealed class SyntaxOnlyReducer : CSharpSyntaxRewriter - { - public bool Changed { get; private set; } - - public override SyntaxNode? VisitParenthesizedExpression(ParenthesizedExpressionSyntax node) - { - var rewritten = (ParenthesizedExpressionSyntax)base.VisitParenthesizedExpression(node)!; - if (CanRemoveParentheses(node)) - { - Changed = true; - return rewritten.Expression.WithTriviaFrom(rewritten); - } - - return rewritten; - } - - public override SyntaxNode? VisitParenthesizedPattern(ParenthesizedPatternSyntax node) - { - var rewritten = (ParenthesizedPatternSyntax)base.VisitParenthesizedPattern(node)!; - Changed = true; - return rewritten.Pattern.WithTriviaFrom(rewritten); - } - - public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) - { - var rewritten = (IdentifierNameSyntax)base.VisitIdentifierName(node)!; - if (rewritten.Identifier.ValueText is not ("Byte" or "Char" or "String") || - node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) - { - return rewritten; - } - - Changed = true; - return GetPredefinedType(rewritten.Identifier.ValueText).WithTriviaFrom(rewritten); - } - - public override SyntaxNode? VisitQualifiedName(QualifiedNameSyntax node) - { - var rewritten = (QualifiedNameSyntax)base.VisitQualifiedName(node)!; - if (rewritten.Left is not IdentifierNameSyntax { Identifier.ValueText: "System" } || - rewritten.Right.Identifier.ValueText is not ("Byte" or "Char" or "String") || - node.Parent is MemberAccessExpressionSyntax or QualifiedNameSyntax) - { - return rewritten; - } - - Changed = true; - return GetPredefinedType(rewritten.Right.Identifier.ValueText).WithTriviaFrom(rewritten); - } - - public override SyntaxNode? VisitCastExpression(CastExpressionSyntax node) - { - var rewritten = (CastExpressionSyntax)base.VisitCastExpression(node)!; - if (rewritten.Expression is LiteralExpressionSyntax literalExpression && - (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || - literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && - (rewritten.Parent is EqualsValueClauseSyntax || node.Parent is EqualsValueClauseSyntax)) - { - Changed = true; - return rewritten.Expression.WithTriviaFrom(rewritten); - } - - return rewritten; - } - - public override SyntaxNode? VisitEqualsValueClause(EqualsValueClauseSyntax node) - { - var rewritten = (EqualsValueClauseSyntax)base.VisitEqualsValueClause(node)!; - if (rewritten.Value is CastExpressionSyntax { Expression: LiteralExpressionSyntax literalExpression } castExpression && - (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || - literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression))) - { - Changed = true; - return rewritten.WithValue(castExpression.Expression.WithTriviaFrom(castExpression)); - } - - return rewritten; - } - - private static PredefinedTypeSyntax GetPredefinedType(string typeName) => typeName switch - { - "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)), - "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)), - "String" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.StringKeyword)), - _ => throw new InvalidOperationException($"Unexpected predefined type name: {typeName}") - }; - } - - private static ISymbol? GetSymbol(SemanticModel semanticModel, NameSyntax name) => - semanticModel.GetSymbolInfo(name).Symbol ?? - semanticModel.GetTypeInfo(name).Type; - - private static async Task ReduceParenthesesAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - if (root == null) - { - return document; - } - - var expressions = root.DescendantNodes() - .OfType() - .Where(CanRemoveParentheses) - .ToList(); - var patterns = root.DescendantNodes() - .OfType() - .ToList(); - if (expressions.Count == 0 && patterns.Count == 0) - { - return document; - } - - var rewrittenRoot = root - .ReplaceNodes( - expressions, - static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)) - .ReplaceNodes( - patterns, - static (_, rewritten) => rewritten.Pattern.WithTriviaFrom(rewritten)); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static bool CanRemoveParentheses(ParenthesizedExpressionSyntax node) => node.Parent switch - { - ParenthesizedExpressionSyntax => true, - IfStatementSyntax ifStatement when ifStatement.Condition == node => true, - WhileStatementSyntax whileStatement when whileStatement.Condition == node => true, - DoStatementSyntax doStatement when doStatement.Condition == node => true, - ForStatementSyntax forStatement when forStatement.Condition == node => true, - SwitchStatementSyntax switchStatement when switchStatement.Expression == node => true, - ReturnStatementSyntax => true, - ArrowExpressionClauseSyntax => true, - EqualsValueClauseSyntax => true, - AssignmentExpressionSyntax assignment when assignment.Right == node => true, - ExpressionStatementSyntax when node.Expression is AssignmentExpressionSyntax => true, - ArgumentSyntax => node.Expression is not AssignmentExpressionSyntax and not ConditionalExpressionSyntax, - BracketedArgumentListSyntax => true, - ConditionalExpressionSyntax parent when parent.Condition == node => true, - WhenClauseSyntax whenClause when whenClause.Condition == node => true, - SwitchExpressionArmSyntax switchArm when switchArm.WhenClause?.Condition == node => true, - BinaryExpressionSyntax parent when parent.Left == node => - GetExpressionPrecedence(node.Expression) >= GetExpressionPrecedence(parent), - BinaryExpressionSyntax parent when parent.Right == node => - GetExpressionPrecedence(node.Expression) > GetExpressionPrecedence(parent) || - parent.IsKind(SyntaxKind.CoalesceExpression) && node.Expression.IsKind(SyntaxKind.CoalesceExpression), - PrefixUnaryExpressionSyntax => node.Expression is CastExpressionSyntax, - _ => false - }; - - private static int GetExpressionPrecedence(ExpressionSyntax expression) => expression.Kind() switch - { - SyntaxKind.SimpleMemberAccessExpression or - SyntaxKind.ElementAccessExpression or - SyntaxKind.InvocationExpression => 15, - SyntaxKind.CastExpression => 14, - SyntaxKind.UnaryMinusExpression or - SyntaxKind.UnaryPlusExpression or - SyntaxKind.LogicalNotExpression or - SyntaxKind.BitwiseNotExpression => 13, - SyntaxKind.MultiplyExpression or - SyntaxKind.DivideExpression or - SyntaxKind.ModuloExpression => 12, - SyntaxKind.AddExpression or - SyntaxKind.SubtractExpression => 11, - SyntaxKind.LeftShiftExpression or - SyntaxKind.RightShiftExpression => 10, - SyntaxKind.LessThanExpression or - SyntaxKind.LessThanOrEqualExpression or - SyntaxKind.GreaterThanExpression or - SyntaxKind.GreaterThanOrEqualExpression or - SyntaxKind.IsExpression or - SyntaxKind.AsExpression => 9, - SyntaxKind.EqualsExpression or - SyntaxKind.NotEqualsExpression => 8, - SyntaxKind.BitwiseAndExpression => 7, - SyntaxKind.ExclusiveOrExpression => 6, - SyntaxKind.BitwiseOrExpression => 5, - SyntaxKind.LogicalAndExpression => 4, - SyntaxKind.LogicalOrExpression => 3, - SyntaxKind.CoalesceExpression => 2, - SyntaxKind.SimpleAssignmentExpression or - SyntaxKind.AddAssignmentExpression or - SyntaxKind.SubtractAssignmentExpression or - SyntaxKind.MultiplyAssignmentExpression or - SyntaxKind.DivideAssignmentExpression or - SyntaxKind.ModuloAssignmentExpression or - SyntaxKind.AndAssignmentExpression or - SyntaxKind.ExclusiveOrAssignmentExpression or - SyntaxKind.OrAssignmentExpression or - SyntaxKind.LeftShiftAssignmentExpression or - SyntaxKind.RightShiftAssignmentExpression or - SyntaxKind.CoalesceAssignmentExpression => 1, - _ => 16 - }; - - private static async Task ReduceGenericMethodTypeArgumentsAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); - if (root == null || semanticModel == null) - { - return document; - } - - var replacements = new Dictionary(); - foreach (var genericName in root.DescendantNodes().OfType()) - { - if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || - memberAccess.Name != genericName || - memberAccess.Parent is not InvocationExpressionSyntax invocation || - invocation.Expression != memberAccess) - { - continue; - } - - var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; - if (originalSymbol == null) - { - continue; - } - - var candidateName = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); - var candidateInvocation = invocation.WithExpression(memberAccess.WithName(candidateName)); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - invocation.SpanStart, - candidateInvocation, - SpeculativeBindingOption.BindAsExpression).Symbol; - if (speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) - { - replacements.Add(genericName, candidateName); - } - } - - if (replacements.Count == 0) - { - return document; - } - - var rewrittenRoot = root.ReplaceNodes( - replacements.Keys, - (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static bool TryGetGenericMethodTypeArgumentsReplacement( - SemanticModel semanticModel, - GenericNameSyntax genericName, - out IdentifierNameSyntax replacement) - { - replacement = SyntaxFactory.IdentifierName(genericName.Identifier).WithTriviaFrom(genericName); - if (genericName.Parent is not MemberAccessExpressionSyntax memberAccess || - memberAccess.Name != genericName || - memberAccess.Parent is not InvocationExpressionSyntax invocation || - invocation.Expression != memberAccess) - { - return false; - } - - var originalSymbol = semanticModel.GetSymbolInfo(invocation).Symbol; - if (originalSymbol == null) - { - return false; - } - - var candidateInvocation = invocation.WithExpression(memberAccess.WithName(replacement)); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - invocation.SpanStart, - candidateInvocation, - SpeculativeBindingOption.BindAsExpression).Symbol; - return speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); - } - - private static async Task ReduceThisQualificationAsync(Document document) + private async Task ProcessDocument(Document document, MemberRemoverRewriter memberRemover) { - var root = await document.GetSyntaxRootAsync(); - var semanticModel = await document.GetSemanticModelAsync(); - if (root == null || semanticModel == null) + var totalStopwatch = PostProcessingProfile == null ? null : Stopwatch.StartNew(); + try { - return document; - } + var root = await MeasurePostProcessingStepAsync("GetSyntaxRootAsync", () => document.GetSyntaxRootAsync()); + var semanticModel = await MeasurePostProcessingStepAsync("GetSemanticModelAsync", () => document.GetSemanticModelAsync()); - var replacements = new Dictionary(); - foreach (var memberAccess in root.DescendantNodes().OfType()) - { - if (memberAccess.Expression is not ThisExpressionSyntax || - IsInUnsupportedQualifiedNameContext(memberAccess)) + if (semanticModel == null || root == null) { - continue; + return document; } - var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; - var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess - ? semanticModel.GetSymbolInfo(invocation).Symbol - : null; - if (originalSymbol == null) - { - originalSymbol = originalInvocationSymbol; - if (originalSymbol == null) - { - continue; - } - } + root = MeasurePostProcessingStep("MemberRemoverRewriter", () => memberRemover.Visit(root)); - var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - memberAccess.SpanStart, - candidate, - SpeculativeBindingOption.BindAsExpression).Symbol; - if (speculativeSymbol == null && - memberAccess.Parent is InvocationExpressionSyntax parentInvocation && - parentInvocation.Expression == memberAccess && - originalInvocationSymbol != null) + foreach (var rewriter in CodeModelGenerator.Instance.Rewriters) { - var candidateInvocation = parentInvocation.WithExpression(candidate); - speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - parentInvocation.SpanStart, - candidateInvocation, - SpeculativeBindingOption.BindAsExpression).Symbol; - originalSymbol = originalInvocationSymbol; + rewriter.SemanticModel = semanticModel; + root = MeasurePostProcessingStep($"CustomRewriter.{rewriter.GetType().Name}", () => rewriter.Visit(root)); } + document = document.WithSyntaxRoot(root); - if (speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) - { - replacements.Add(memberAccess, candidate); - } - } - - if (replacements.Count == 0) - { - return document; - } - - var rewrittenRoot = root.ReplaceNodes( - replacements.Keys, - (original, rewritten) => replacements[original].WithTriviaFrom(rewritten)); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static bool TryGetThisQualificationReplacement( - SemanticModel semanticModel, - MemberAccessExpressionSyntax memberAccess, - out ExpressionSyntax replacement) - { - replacement = memberAccess; - if (memberAccess.Expression is not ThisExpressionSyntax) - { - return false; - } + document = await MeasurePostProcessingStepAsync("Roslyn.Simplifier.ReduceAsync", () => Simplifier.ReduceAsync(document)); - var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; - var originalInvocationSymbol = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess - ? semanticModel.GetSymbolInfo(invocation).Symbol - : null; - if (originalSymbol == null) - { - originalSymbol = originalInvocationSymbol; - if (originalSymbol == null) + // Reformat if any custom rewriters have been applied + if (CodeModelGenerator.Instance.Rewriters.Count > 0) { - return false; + document = await MeasurePostProcessingStepAsync("Formatter.FormatAsync", () => Formatter.FormatAsync(document)); } - } - - var candidate = memberAccess.Name.WithTriviaFrom(memberAccess); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - memberAccess.SpanStart, - candidate, - SpeculativeBindingOption.BindAsExpression).Symbol; - if (speculativeSymbol == null && - memberAccess.Parent is InvocationExpressionSyntax parentInvocation && - parentInvocation.Expression == memberAccess && - originalInvocationSymbol != null) - { - var candidateInvocation = parentInvocation.WithExpression(candidate); - speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - parentInvocation.SpanStart, - candidateInvocation, - SpeculativeBindingOption.BindAsExpression).Symbol; - originalSymbol = originalInvocationSymbol; - } - - if (speculativeSymbol == null || - !SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) - { - return false; - } - - replacement = candidate; - return true; - } - - private static async Task ReducePredefinedTypeNamesAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - if (root == null) - { return document; } - - var identifiers = root.DescendantNodes() - .OfType() - .Where(static identifier => - identifier.Identifier.ValueText is "Byte" or "Char" && - identifier.Parent is not MemberAccessExpressionSyntax and not QualifiedNameSyntax) - .ToList(); - var castExpressions = root.DescendantNodes() - .OfType() - .Where(static castExpression => - castExpression.Expression is LiteralExpressionSyntax literalExpression && - (literalExpression.IsKind(SyntaxKind.NullLiteralExpression) || - literalExpression.IsKind(SyntaxKind.DefaultLiteralExpression)) && - castExpression.Parent is EqualsValueClauseSyntax) - .ToList(); - - if (identifiers.Count == 0 && castExpressions.Count == 0) + finally { - return document; - } - - var rewrittenRoot = root - .ReplaceNodes( - identifiers, - static (_, rewritten) => rewritten.Identifier.ValueText switch - { - "Byte" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.ByteKeyword)).WithTriviaFrom(rewritten), - "Char" => SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.CharKeyword)).WithTriviaFrom(rewritten), - _ => rewritten - }) - .ReplaceNodes( - castExpressions, - static (_, rewritten) => rewritten.Expression.WithTriviaFrom(rewritten)); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static async Task ReduceDocumentationQualifiedNamesAsync(Document document) - { - var root = await document.GetSyntaxRootAsync(); - if (root == null) - { - return document; - } - - var documentationTrivia = root.DescendantTrivia(descendIntoTrivia: true) - .Where(static trivia => - trivia.HasStructure && - trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia) && - (trivia.ToFullString().Contains("global::", StringComparison.Ordinal) || - trivia.ToFullString().Contains("cref=\"", StringComparison.Ordinal))) - .ToList(); - if (documentationTrivia.Count == 0) - { - return document; - } - - var rewrittenRoot = root.ReplaceTrivia( - documentationTrivia, - (original, rewritten) => - { - var reduced = ReduceDocumentationTriviaText(root, original, rewritten.ToFullString()); - var parsedTrivia = SyntaxFactory.ParseLeadingTrivia(reduced); - return parsedTrivia.Count == 1 ? parsedTrivia[0] : rewritten; - }); - return document.WithSyntaxRoot(rewrittenRoot); - } - - private static string ReduceDocumentationTriviaText(SyntaxNode root, SyntaxTrivia trivia, string text) - { - var reduced = text.Replace("global::", string.Empty, StringComparison.Ordinal); - var namespacePrefix = trivia.Token.Parent? - .AncestorsAndSelf() - .OfType() - .FirstOrDefault()? - .Name - .ToString(); - - var namespacePrefixes = root.DescendantNodes() - .OfType() - .Where(static directive => directive is { Alias: null, StaticKeyword.RawKind: 0, Name: not null }) - .Select(static directive => directive.Name!.ToString()) - .Append(namespacePrefix) - .Where(static prefix => !string.IsNullOrEmpty(prefix)) - .Distinct(StringComparer.Ordinal) - .OrderByDescending(static prefix => prefix!.Length); - - var prefixes = namespacePrefixes.Select(static prefix => prefix! + ".").ToArray(); - return Regex.Replace( - reduced, - @"(?(?:cref|name)="")(?[^""]*)(?"")", - match => + if (totalStopwatch != null) { - var value = match.Groups["value"].Value; - if (IsMockingReturnsCref(reduced, match.Index)) - { - value = value.Replace("SampleTypeSpec.Models.Custom.", "Models.Custom.", StringComparison.Ordinal); - return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; - } - - if (IsAbstractDerivedTypesCref(trivia, reduced, match.Index)) - { - return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; - } - - foreach (var prefix in prefixes) - { - value = value.Replace(prefix, string.Empty, StringComparison.Ordinal); - } - - return match.Groups["attribute"].Value + value + match.Groups["quote"].Value; - }); - } - - private static bool IsMockingReturnsCref(string text, int index) - { - var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); - var lineEnd = text.IndexOf('\n', index); - lineStart = lineStart < 0 ? 0 : lineStart + 1; - lineEnd = lineEnd < 0 ? text.Length : lineEnd; - return text.Substring(lineStart, lineEnd - lineStart).Contains(" instance for mocking.", StringComparison.Ordinal); - } - - private static bool IsAbstractDerivedTypesCref(SyntaxTrivia trivia, string text, int index) - { - if (trivia.Token.Parent? - .AncestorsAndSelf() - .OfType() - .FirstOrDefault()? - .Identifier.ValueText.EndsWith("ModelFactory", StringComparison.Ordinal) != true) - { - return false; - } - - var lineStart = text.LastIndexOf('\n', Math.Max(0, index - 1)); - var lineEnd = text.IndexOf('\n', index); - lineStart = lineStart < 0 ? 0 : lineStart + 1; - lineEnd = lineEnd < 0 ? text.Length : lineEnd; - return text.Substring(lineStart, lineEnd - lineStart).Contains("derived classes available for instantiation", StringComparison.Ordinal); - } - - private static bool TryGetNameReplacement( - SemanticModel semanticModel, - NameSyntax originalName, - ISymbol originalSymbol, - out NameSyntax replacement) - { - replacement = originalName; - if (!TryGetNameParts(originalName, out var parts)) - { - return false; - } - - for (int i = parts.Count - 1; i >= 0; i--) - { - var candidate = BuildName(parts, i).WithTriviaFrom(originalName); - if (SpeculativelyBindsToSameSymbol(semanticModel, originalName, candidate, originalSymbol)) - { - replacement = candidate; - return true; + totalStopwatch.Stop(); + PostProcessingProfile?.Add("ProcessDocument.Total", totalStopwatch.Elapsed, 0); } } - - return false; } - private static bool SpeculativelyBindsToSameSymbol( - SemanticModel semanticModel, - NameSyntax originalName, - NameSyntax replacement, - ISymbol originalSymbol) + private static T MeasurePostProcessingStep(string stepName, Func action) { - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - originalName.SpanStart, - replacement, - SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; - if (speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) + var profile = PostProcessingProfile; + if (profile == null) { - return true; + return action(); } - if (originalSymbol is ITypeSymbol originalType) - { - var speculativeType = semanticModel.GetSpeculativeTypeInfo( - originalName.SpanStart, - replacement, - SpeculativeBindingOption.BindAsTypeOrNamespace).Type; - if (speculativeType != null && - SymbolEqualityComparer.Default.Equals(originalType, speculativeType)) - { - return true; - } - } - - if (originalName.Parent is MemberAccessExpressionSyntax memberAccess && - memberAccess.Expression == originalName) - { - speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - originalName.SpanStart, - replacement, - SpeculativeBindingOption.BindAsExpression).Symbol; - return speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol); - } - - return false; - } - - private static bool TryGetNameParts(NameSyntax name, out IReadOnlyList parts) - { - var builder = new List(); - if (AddNameParts(name, builder)) - { - parts = builder; - return true; - } - - parts = []; - return false; - } - - private static bool AddNameParts(NameSyntax name, List parts) - { - switch (name) - { - case SimpleNameSyntax simpleName: - parts.Add(simpleName); - return true; - case QualifiedNameSyntax qualifiedName: - if (!AddNameParts(qualifiedName.Left, parts)) - { - return false; - } - - parts.Add(qualifiedName.Right); - return true; - case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: - parts.Add(aliasQualifiedName.Name); - return true; - default: - return false; - } - } - - private static NameSyntax BuildName(IReadOnlyList parts, int startIndex) - { - NameSyntax name = parts[startIndex]; - for (int i = startIndex + 1; i < parts.Count; i++) - { - name = SyntaxFactory.QualifiedName(name, parts[i]); - } - - return name; - } - - private static SimpleNameSyntax GetRightmostName(NameSyntax name) => name switch - { - QualifiedNameSyntax qualifiedName => qualifiedName.Right, - AliasQualifiedNameSyntax aliasQualifiedName => aliasQualifiedName.Name, - SimpleNameSyntax simpleName => simpleName, - _ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}") - }; - - private static bool TryGetAttributeNameReplacement( - SemanticModel semanticModel, - AttributeSyntax attribute, - out NameSyntax replacement) - { - replacement = attribute.Name; - if (attribute.Name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax) - { - return false; - } - - var originalSymbol = semanticModel.GetSymbolInfo(attribute).Symbol; - if (originalSymbol is not IMethodSymbol { ContainingType: { } originalAttributeType }) - { - return false; - } - - var rightmostName = GetRightmostName(attribute.Name); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - attribute.Name.SpanStart, - rightmostName, - SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol; - if (!SymbolEqualityComparer.Default.Equals(originalAttributeType, speculativeSymbol)) - { - return false; - } - - replacement = TrimAttributeSuffix(rightmostName).WithTriviaFrom(attribute.Name); - return true; - } - - private static SimpleNameSyntax TrimAttributeSuffix(SimpleNameSyntax name) - { - const string AttributeSuffix = "Attribute"; - var identifier = name.Identifier; - var text = identifier.ValueText; - if (!text.EndsWith(AttributeSuffix, StringComparison.Ordinal) || text.Length == AttributeSuffix.Length) - { - return name; - } - - return SyntaxFactory.IdentifierName( - SyntaxFactory.Identifier( - identifier.LeadingTrivia, - text.Substring(0, text.Length - AttributeSuffix.Length), - identifier.TrailingTrivia)); - } - - private static bool TryGetMemberAccessReplacement( - SemanticModel semanticModel, - MemberAccessExpressionSyntax memberAccess, - out ExpressionSyntax replacement) - { - replacement = memberAccess; - var originalSymbol = semanticModel.GetSymbolInfo(memberAccess).Symbol; - var invocationExpression = memberAccess.Parent is InvocationExpressionSyntax invocation && invocation.Expression == memberAccess - ? invocation - : null; - var originalInvocationSymbol = invocationExpression != null - ? semanticModel.GetSymbolInfo(invocationExpression).Symbol - : null; - if (memberAccess.Expression is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.IntKeyword or (int)SyntaxKind.FloatKeyword } && - originalInvocationSymbol != null && - TryReduceInvocationExpression(semanticModel, invocationExpression!, memberAccess.Name, originalInvocationSymbol)) - { - replacement = memberAccess.Name.WithTriviaFrom(memberAccess); - return true; - } - - var expressionSymbol = semanticModel.GetSymbolInfo(memberAccess.Expression).Symbol; - if (expressionSymbol is not null and not INamespaceSymbol and not INamedTypeSymbol) + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try { - return false; + return action(); } - - if (originalSymbol == null || - !TryGetMemberAccessParts(memberAccess, out var parts) || - parts.Count < 2) + finally { - originalSymbol = originalInvocationSymbol; - if (originalSymbol == null || - !TryGetMemberAccessParts(memberAccess, out parts) || - parts.Count < 2) - { - return false; - } + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); } - - for (int i = parts.Count - 1; i > 0; i--) - { - var candidate = BuildMemberAccess(parts, i).WithTriviaFrom(memberAccess); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - memberAccess.SpanStart, - candidate, - SpeculativeBindingOption.BindAsExpression).Symbol; - if (speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol)) - { - replacement = candidate; - return true; - } - - if (memberAccess.Parent is InvocationExpressionSyntax parentInvocation && - parentInvocation.Expression == memberAccess && - originalInvocationSymbol != null && - TryReduceInvocationExpression(semanticModel, parentInvocation, candidate, originalInvocationSymbol)) - { - replacement = candidate; - return true; - } - } - - return false; - } - - private static bool TryReduceInvocationExpression( - SemanticModel semanticModel, - InvocationExpressionSyntax invocation, - ExpressionSyntax candidateExpression, - ISymbol originalInvocationSymbol) - { - var candidateInvocation = invocation.WithExpression(candidateExpression); - var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo( - invocation.SpanStart, - candidateInvocation, - SpeculativeBindingOption.BindAsExpression).Symbol; - return speculativeSymbol != null && - SymbolEqualityComparer.Default.Equals(originalInvocationSymbol, speculativeSymbol); } - private static bool TryGetMemberAccessParts(ExpressionSyntax expression, out IReadOnlyList parts) + private static async Task MeasurePostProcessingStepAsync(string stepName, Func> action) { - var builder = new List(); - if (AddMemberAccessParts(expression, builder)) + var profile = PostProcessingProfile; + if (profile == null) { - parts = builder; - return true; + return await action(); } - parts = []; - return false; - } - - private static bool AddMemberAccessParts(SyntaxNode expression, List parts) - { - switch (expression) - { - case SimpleNameSyntax name: - parts.Add(name); - return true; - case MemberAccessExpressionSyntax memberAccess: - if (!AddMemberAccessParts(memberAccess.Expression, parts)) - { - return false; - } - - parts.Add(memberAccess.Name); - return true; - case QualifiedNameSyntax qualifiedName: - if (!AddMemberAccessParts(qualifiedName.Left, parts)) - { - return false; - } - - parts.Add(qualifiedName.Right); - return true; - case AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" } aliasQualifiedName: - parts.Add(aliasQualifiedName.Name); - return true; - default: - return false; - } - } - - private static ExpressionSyntax BuildMemberAccess(IReadOnlyList parts, int startIndex) - { - ExpressionSyntax expression = parts[startIndex]; - for (int i = startIndex + 1; i < parts.Count; i++) + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try { - expression = SyntaxFactory.MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - expression, - parts[i]); + return await action(); } - - return expression; - } - - private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) => - name.Ancestors().Any(static ancestor => - ancestor is UsingDirectiveSyntax || - ancestor is CrefSyntax); - - private static bool IsInUnsupportedQualifiedNameContext(ExpressionSyntax expression) => - expression.Ancestors().Any(static ancestor => - ancestor is CrefSyntax); - - private static bool ContainsSimplifierAnnotations(SyntaxNode root) => - root.HasAnnotation(Simplifier.Annotation) || - root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken => - nodeOrToken.HasAnnotation(Simplifier.Annotation)); - - private static IReadOnlyList GetSimplifierSpans(SyntaxNode root) - { - List spans = new(); - foreach (var member in root.DescendantNodes().OfType()) + finally { - if (ContainsReducibleSyntax(member)) - { - spans.Add(member.FullSpan); - } + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); } - - spans.AddRange(root - .DescendantNodesAndTokens(descendIntoTrivia: true) - .Where(static nodeOrToken => nodeOrToken.HasAnnotation(Simplifier.Annotation)) - .Select(static nodeOrToken => nodeOrToken.FullSpan)); - - return spans; } - private static bool ContainsReducibleSyntax(SyntaxNode root) => - root.DescendantNodes( - descendIntoChildren: node => node == root || node is not MemberDeclarationSyntax, - descendIntoTrivia: true).Any(static node => - node is ThisExpressionSyntax || - node is ParenthesizedExpressionSyntax || - node is CrefSyntax || - node is QualifiedNameSyntax || - node is MemberAccessExpressionSyntax || - node is AssignmentExpressionSyntax { RawKind: (int)SyntaxKind.SimpleAssignmentExpression } || - node is AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" }); - public static bool IsGeneratedDocument(Document document) => document.Folders.Contains(GeneratedFolder); public static bool IsCustomDocument(Document document) => !IsGeneratedDocument(document); public static bool IsGeneratedTestDocument(Document document) => document.Folders.Contains(GeneratedTestFolder); @@ -1313,11 +343,11 @@ public async Task PostProcessAsync() case Configuration.UnreferencedTypesHandlingOption.KeepAll: break; case Configuration.UnreferencedTypesHandlingOption.Internalize: - _project = await postProcessor.InternalizeAsync(_project); + _project = await MeasurePostProcessingStepAsync("PostProcess.InternalizeAsync", () => postProcessor.InternalizeAsync(_project)); break; case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize: - _project = await postProcessor.InternalizeAsync(_project); - _project = await postProcessor.RemoveAsync(_project); + _project = await MeasurePostProcessingStepAsync("PostProcess.InternalizeAsync", () => postProcessor.InternalizeAsync(_project)); + _project = await MeasurePostProcessingStepAsync("PostProcess.RemoveAsync", () => postProcessor.RemoveAsync(_project)); break; } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs new file mode 100644 index 00000000000..74a9a58168a --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspacePostProcessingProfile.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed class GeneratedCodeWorkspacePostProcessingProfile + { + private readonly object _syncRoot = new(); + private readonly Dictionary _steps = new(StringComparer.Ordinal); + + public void Add(string stepName, TimeSpan elapsed, long allocatedBytes) + { + lock (_syncRoot) + { + ref var summary = ref CollectionsMarshal.GetValueRefOrAddDefault(_steps, stepName, out _); + summary.Count++; + summary.ElapsedTicks += elapsed.Ticks; + summary.AllocatedBytes += allocatedBytes; + } + } + + public string GetSummary() + { + KeyValuePair[] steps; + lock (_syncRoot) + { + steps = _steps.ToArray(); + } + + var totalTicks = steps + .Where(static step => step.Key != "ProcessDocument.Total") + .Sum(static step => step.Value.ElapsedTicks); + var builder = new StringBuilder(); + builder.AppendLine("Post-processing step profile:"); + builder.AppendLine("Step, Count, Total ms, Avg ms, Percent of measured steps, Allocated bytes, Avg allocated bytes"); + + foreach (var step in steps.OrderByDescending(static step => step.Value.ElapsedTicks)) + { + var elapsedMs = TimeSpan.FromTicks(step.Value.ElapsedTicks).TotalMilliseconds; + var averageMs = elapsedMs / step.Value.Count; + var averageAllocatedBytes = step.Value.AllocatedBytes / step.Value.Count; + var percentage = totalTicks == 0 || step.Key == "ProcessDocument.Total" + ? 0 + : step.Value.ElapsedTicks * 100.0 / totalTicks; + builder.AppendLine($"{step.Key}, {step.Value.Count}, {elapsedMs:F3}, {averageMs:F3}, {percentage:F1}%, {step.Value.AllocatedBytes}, {averageAllocatedBytes}"); + } + + return builder.ToString(); + } + + private struct StepSummary + { + public int Count; + public long ElapsedTicks; + public long AllocatedBytes; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index be96e11df59..d5db7e47298 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -20,6 +20,8 @@ internal class PostProcessor private readonly HashSet _typesToKeep; private INamedTypeSymbol? _modelFactorySymbol; + private static GeneratedCodeWorkspacePostProcessingProfile? Profile => GeneratedCodeWorkspace.PostProcessingProfile; + public PostProcessor( HashSet typesToKeep, string? modelFactoryFullName = null, @@ -125,40 +127,55 @@ protected virtual bool ShouldIncludeDocument(Document document) => /// The processed . is immutable, therefore this should usually be a new instance public async Task InternalizeAsync(Project project) { - var compilation = await project.GetCompilationAsync(); + var compilation = await MeasureAsync("PostProcessor.Internalize.GetCompilationAsync", () => project.GetCompilationAsync()); if (compilation == null) return project; // first get all the declared symbols - var definitions = await GetTypeSymbolsAsync(compilation, project, true); + var definitions = await MeasureAsync("PostProcessor.Internalize.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, true)); // build the reference map var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache); + await MeasureAsync( + "PostProcessor.Internalize.BuildPublicReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DeclaredNodesCache)); // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); + var rootSymbols = await MeasureAsync("PostProcessor.Internalize.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + var publicSymbols = Measure("PostProcessor.Internalize.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray()); var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); - var nodesToInternalize = new Dictionary(); - foreach (var symbol in symbolsToInternalize) + var nodesToInternalize = Measure("PostProcessor.Internalize.CollectNodes", () => { - foreach (var node in definitions.DeclaredNodesCache[symbol]) + var nodes = new Dictionary(); + foreach (var symbol in symbolsToInternalize) { - nodesToInternalize[node] = project.GetDocumentId(node.SyntaxTree)!; + foreach (var node in definitions.DeclaredNodesCache[symbol]) + { + nodes[node] = project.GetDocumentId(node.SyntaxTree)!; + } } - } - foreach (var (model, documentId) in nodesToInternalize) + return nodes; + }); + + project = Measure("PostProcessor.Internalize.MarkInternal", () => { - project = MarkInternal(project, model, documentId); - } + var updatedProject = project; + foreach (var (model, documentId) in nodesToInternalize) + { + updatedProject = MarkInternal(updatedProject, model, documentId); + } + + return updatedProject; + }); var modelNamesToRemove = nodesToInternalize.Keys.Select(item => item.Identifier.Text); - project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet()); + project = await MeasureAsync( + "PostProcessor.Internalize.RemoveMethodsFromModelFactoryAsync", + () => RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet())); return project; } @@ -232,42 +249,49 @@ private async Task RemoveMethodsFromModelFactoryAsync(Project project, /// The processed . is immutable, therefore this should usually be a new instance public async Task RemoveAsync(Project project) { - var compilation = await project.GetCompilationAsync(); + var compilation = await MeasureAsync("PostProcessor.Remove.GetCompilationAsync", () => project.GetCompilationAsync()); if (compilation == null) return project; // find all the declarations, including non-public declared - var definitions = await GetTypeSymbolsAsync(compilation, project, false); + var definitions = await MeasureAsync("PostProcessor.Remove.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, false)); // build reference map var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache); + await MeasureAsync( + "PostProcessor.Remove.BuildAllReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache)); // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); + var rootSymbols = await MeasureAsync("PostProcessor.Remove.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal // helpers that are required by the model factory. if (_modelFactorySymbol != null) rootSymbols.Add(_modelFactorySymbol); // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + var referencedSymbols = Measure("PostProcessor.Remove.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray().AsEnumerable()); - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + referencedSymbols = Measure("PostProcessor.Remove.AddSampleSymbols", () => AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols)); + var referencedSet = Measure("PostProcessor.Remove.BuildReferencedSet", () => new HashSet(referencedSymbols, SymbolEqualityComparer.Default)); var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); - var nodesToRemove = new List(); - foreach (var symbol in symbolsToRemove) + var nodesToRemove = Measure("PostProcessor.Remove.CollectNodes", () => { - if (referencedSet.Contains(GetBase(symbol))) + var nodes = new List(); + foreach (var symbol in symbolsToRemove) { - continue; + if (referencedSet.Contains(GetBase(symbol))) + { + continue; + } + nodes.AddRange(definitions.DeclaredNodesCache[symbol]); } - nodesToRemove.AddRange(definitions.DeclaredNodesCache[symbol]); - } + + return nodes; + }); // remove them one by one - project = await RemoveModelsAsync(project, nodesToRemove); + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync", () => RemoveModelsAsync(project, nodesToRemove)); return project; } @@ -349,25 +373,35 @@ private async Task RemoveModelsAsync(Project project, IEnumerable unusedModels) { // accumulate the definitions from the same document together - var documents = new Dictionary>(); - - foreach (var model in unusedModels) + var documents = Measure("PostProcessor.Remove.RemoveModelsAsync.GroupByDocument", () => { - var document = project.GetDocument(model.SyntaxTree); - Debug.Assert(document != null); - if (!documents.ContainsKey(document)) - documents.Add(document, new HashSet()); + var groupedDocuments = new Dictionary>(); + foreach (var model in unusedModels) + { + var document = project.GetDocument(model.SyntaxTree); + Debug.Assert(document != null); + if (!groupedDocuments.ContainsKey(document)) + groupedDocuments.Add(document, new HashSet()); - documents[document].Add(model); - } + groupedDocuments[document].Add(model); + } - foreach (var models in documents.Values) + return groupedDocuments; + }); + + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync.RemoveModelsFromDocuments", async () => { - project = await RemoveModelsFromDocumentAsync(project, models); - } + var updatedProject = project; + foreach (var models in documents.Values) + { + updatedProject = await RemoveModelsFromDocumentAsync(updatedProject, models); + } + + return updatedProject; + }); // remove what are now invalid references due to the models being removed - project = await RemoveInvalidRefs(project); + project = await MeasureAsync("PostProcessor.Remove.RemoveModelsAsync.RemoveInvalidRefs", () => RemoveInvalidRefs(project)); return project; } @@ -418,16 +452,28 @@ private async Task RemoveInvalidRefs(Project project) var solution = project.Solution; // Process each document for invalid usings - foreach (var documentId in project.DocumentIds) + solution = await MeasureAsync("PostProcessor.Remove.RemoveInvalidRefs.RemoveInvalidUsings", async () => { - solution = await RemoveInvalidUsings(solution, documentId); - } + var updatedSolution = solution; + foreach (var documentId in project.DocumentIds) + { + updatedSolution = await RemoveInvalidUsings(updatedSolution, documentId); + } + + return updatedSolution; + }); // Process each document for invalid attributes (with fresh semantic models) - foreach (var documentId in project.DocumentIds) + solution = await MeasureAsync("PostProcessor.Remove.RemoveInvalidRefs.RemoveInvalidAttributes", async () => { - solution = await RemoveInvalidAttributes(solution, documentId); - } + var updatedSolution = solution; + foreach (var documentId in project.DocumentIds) + { + updatedSolution = await RemoveInvalidAttributes(updatedSolution, documentId); + } + + return updatedSolution; + }); return solution.GetProject(project.Id)!; } @@ -533,6 +579,48 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && return solution; } + private static T Measure(string stepName, Func action) + { + var profile = Profile; + if (profile == null) + { + return action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + + private static async Task MeasureAsync(string stepName, Func> action) + { + var profile = Profile; + if (profile == null) + { + return await action(); + } + + var allocatedBytes = GC.GetTotalAllocatedBytes(precise: false); + var stopwatch = Stopwatch.StartNew(); + try + { + return await action(); + } + finally + { + stopwatch.Stop(); + profile.Add(stepName, stopwatch.Elapsed, GC.GetTotalAllocatedBytes(precise: false) - allocatedBytes); + } + } + private async Task> GetRootSymbolsAsync(Project project, TypeSymbols modelSymbols) { var result = new HashSet(SymbolEqualityComparer.Default); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs index d1aec2a8486..0a75d8c9360 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/GeneratedCodeWorkspaceTests.cs @@ -4,7 +4,6 @@ using Microsoft.Build.Construction; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; -using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Tests.Common; using NUnit.Framework; using System; @@ -98,514 +97,6 @@ await MockHelpers.LoadMockGeneratorAsync( Assert.NotNull(fooMethod, "Foo method should be found in the SimpleType"); } - [Test] - public async Task GetGeneratedFilesAsync_SimplifiesFrameworkNamesWhenTypeHasSystemMember() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// - -#nullable disable - -using System; -using System.ComponentModel; - -namespace TestNamespace -{ - public readonly partial struct TestRole : IEquatable - { - private readonly string _value; - private const string SystemValue = "system"; - - public TestRole(string value) - { - _value = value; - } - - public static TestRole System { get; } = new TestRole(SystemValue); - - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] - public override bool Equals(object obj) => obj is TestRole other && Equals(other); - - public bool Equals(TestRole other) => string.Equals(_value, other._value, global::System.StringComparison.InvariantCultureIgnoreCase); - - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)] - public override int GetHashCode() => _value != null ? global::System.StringComparer.InvariantCultureIgnoreCase.GetHashCode(_value) : 0; - } -} -""", - typeof(EditorBrowsableAttribute).Assembly.Location); - - Assert.That(generatedText, Is.Not.Null); - Assert.That(generatedText, Does.Contain("[EditorBrowsable(EditorBrowsableState.Never)]")); - Assert.That(generatedText, Does.Contain("StringComparison.InvariantCultureIgnoreCase")); - Assert.That(generatedText, Does.Contain("StringComparer.InvariantCultureIgnoreCase")); - Assert.That(generatedText, Does.Not.Contain("System.ComponentModel.EditorBrowsableAttribute")); - Assert.That(generatedText, Does.Not.Contain("System.StringComparison")); - Assert.That(generatedText, Does.Not.Contain("System.StringComparer")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesQualificationWhenImportedNamespacesContainSameTypeName() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using First; -using Second; - -namespace TestNamespace -{ - public class Container - { - public global::First.Conflict FirstValue { get; } - public global::Second.Conflict SecondValue { get; } - } -} - -namespace First -{ - public class Conflict { } -} - -namespace Second -{ - public class Conflict { } -} -"""); - - Assert.That(generatedText, Does.Contain("First.Conflict FirstValue")); - Assert.That(generatedText, Does.Contain("Second.Conflict SecondValue")); - Assert.That(generatedText, Does.Not.Contain("public Conflict FirstValue")); - Assert.That(generatedText, Does.Not.Contain("public Conflict SecondValue")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesFrameworkQualificationWhenGeneratedModelShadowsFrameworkType() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System; - -namespace TestNamespace -{ - public class BinaryData { } - - public class Container - { - public global::System.BinaryData Payload { get; } - } -} -"""); - - Assert.That(generatedText, Does.Contain("System.BinaryData Payload")); - Assert.That(generatedText, Does.Not.Contain("public BinaryData Payload")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesQualificationWhenCurrentNamespaceTypeShadowsImportedType() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using External; - -namespace TestNamespace -{ - public class Widget { } - - public class Container - { - public global::External.Widget ExternalWidget { get; } - } -} - -namespace External -{ - public class Widget { } -} -"""); - - Assert.That(generatedText, Does.Contain("External.Widget ExternalWidget")); - Assert.That(generatedText, Does.Not.Contain("public Widget ExternalWidget")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesQualificationWhenParameterNameConflictsWithTypeName() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System; - -namespace TestNamespace -{ - public class Container - { - public bool Equals(string StringComparison) => global::System.StringComparison.InvariantCultureIgnoreCase.Equals(StringComparison, StringComparison); - } -} -"""); - - Assert.That(generatedText, Does.Contain("System.StringComparison.InvariantCultureIgnoreCase")); - Assert.That(generatedText, Does.Not.Contain("=> StringComparison.InvariantCultureIgnoreCase")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesGlobalAliasesInXmlDocCrefsSeparatelyFromCode() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System; - -namespace TestNamespace -{ - /// See . - public class Container - { - public global::System.ArgumentNullException Create() => null; - } -} -"""); - - Assert.That(generatedText, Does.Contain("")); - Assert.That(generatedText, Does.Contain("ArgumentNullException Create()")); - Assert.That(generatedText, Does.Not.Contain("global::System.ArgumentNullException")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesQualifiedXmlDocCrefs() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System; -using System.Text.Json; - -namespace TestNamespace -{ - /// See and . - /// The derived classes available for instantiation are: . - /// A new instance for mocking. - public class Widget - { - public JsonSerializerOptions Options { get; } - } -} -"""); - - Assert.That(generatedText, Does.Contain("")); - Assert.That(generatedText, Does.Contain("derived classes available for instantiation are: ")); - Assert.That(generatedText, Does.Contain("")); - Assert.That(generatedText, Does.Contain(" instance for mocking")); - Assert.That(generatedText, Does.Not.Contain("System.Text.Json.JsonSerializer")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesAliasesGenericNamesAndCustomizationTypesSafely() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System.Collections.Generic; -using AliasWidget = Customization.Widget; - -namespace TestNamespace -{ - public class Container - { - public global::System.Collections.Generic.IList Widgets { get; } - public global::Customization.Widget Create(AliasWidget widget) => widget; - public string GetFormat() => ((IPersistableModel)this).GetFormatFromOptions(null); - } - - public interface IPersistableModel - { - string GetFormatFromOptions(object options); - } -} - -namespace Customization -{ - public class Widget { } -} -"""); - - Assert.That(generatedText, Does.Contain("IList Widgets")); - Assert.That(generatedText, Does.Contain("Customization.Widget Create(AliasWidget widget)")); - Assert.That(generatedText, Does.Contain("((IPersistableModel)this).GetFormatFromOptions(null)")); - Assert.That(generatedText, Does.Not.Contain("global::System.Collections.Generic.IList")); - Assert.That(generatedText, Does.Not.Contain("global::Customization.Widget")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesQualifiedGenericTypeNames() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System.ClientModel; -using System.Threading.Tasks; - -namespace TestNamespace -{ - public class Widget { } - - public class Operations - { - public Task> GetAsync() => null; - } -} -""", - typeof(System.ClientModel.ClientResult).Assembly.Location); - - Assert.That(generatedText, Does.Contain("Task> GetAsync()")); - Assert.That(generatedText, Does.Not.Contain("System.ClientModel.ClientResult")); - Assert.That(generatedText, Does.Not.Contain("TestNamespace.Widget")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesSameNamespaceStaticMemberAccess() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -namespace TestNamespace -{ - public class Container - { - public void Invoke(string value) - { - TestNamespace.Argument.AssertNotNull(value, nameof(value)); - } - } - - internal static class Argument - { - public static void AssertNotNull(object value, string name) { } - } -} -"""); - - Assert.That(generatedText, Does.Contain("Argument.AssertNotNull(value, nameof(value));")); - Assert.That(generatedText, Does.Not.Contain("TestNamespace.Argument.AssertNotNull")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesStaticMemberQualificationWhenShortNameConflicts() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -namespace TestNamespace -{ - public class Container - { - public void Invoke(string value) - { - var Argument = new LocalArgument(); - TestNamespace.Argument.AssertNotNull(value, nameof(value)); - } - } - - internal static class Argument - { - public static void AssertNotNull(object value, string name) { } - } - - internal class LocalArgument - { - public void AssertNotNull(object value, string name) { } - } -} -"""); - - Assert.That(generatedText, Does.Contain("TestNamespace.Argument.AssertNotNull(value, nameof(value));")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesGeneratedParentheses() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -using System.Collections.Generic; - -namespace TestNamespace -{ - public class Container - { - private Dictionary _map; - private string[] _items; - private int _length; - - public Dictionary Map => (_map ??= new Dictionary()); - - public void Invoke(object writer, WidgetHolder widget, object value, object result, Byte[] bytes, Char[] chars, Options options = (Options)null, string nameHint = (String)null, global::System.String title = (global::System.String)null) - { - if (((value is ICollection collection) && (collection.Count == 0))) - { - Use(((Widget)result)); - } - - if ((_items[(collection.Count - 1)] == null)) - { - return; - } - - switch (collection.Count) - { - case ((>= 200) and (< 300)): - return; - } - - switch (value) - { - case string s when (s.Length > 0): - return; - } - - _map = _map ?? (GetMap() ?? new Dictionary()); - _length = (_length + 1); - string format = (nameHint == "W") ? nameHint : "J"; - string converted = TypeFormatters.ToString(bytes); - writer.WriteObjectValue((Widget)result, options); - widget.Value.Equals(widget.Value); - } - - private void Use(Widget widget) { } - private Dictionary GetMap() => null; - } - - public class Widget { } - public class WidgetHolder - { - public Widget Value { get; } - } - public class Options { } - - internal static class TypeFormatters - { - public static string ToString(byte[] value) => null; - public static string Invoke(byte[] value) => TypeFormatters.ToString(value); - } - - internal static class WriterExtensions - { - public static void WriteObjectValue(this object writer, T value, Options options) { } - } -} -"""); - - Assert.That(generatedText, Does.Contain("Map => _map ??= new Dictionary();")); - Assert.That(generatedText, Does.Contain("if (value is ICollection collection && collection.Count == 0)")); - Assert.That(generatedText, Does.Contain("Use((Widget)result);")); - Assert.That(generatedText, Does.Contain("_items[collection.Count - 1] == null")); - Assert.That(generatedText, Does.Contain("case >= 200 and < 300:")); - Assert.That(generatedText, Does.Contain("byte[] bytes")); - Assert.That(generatedText, Does.Contain("char[] chars")); - Assert.That(generatedText, Does.Contain("Options options = null")); - Assert.That(generatedText, Does.Contain("string nameHint = null")); - Assert.That(generatedText, Does.Contain("string title = null")); - Assert.That(generatedText, Does.Contain("case string s when s.Length > 0:")); - Assert.That(generatedText, Does.Contain("_map = _map ?? GetMap() ?? new Dictionary();")); - Assert.That(generatedText, Does.Contain("_length = _length + 1;")); - Assert.That(generatedText, Does.Contain("string format = nameHint == \"W\" ? nameHint : \"J\";")); - Assert.That(generatedText, Does.Contain("TypeFormatters.ToString(bytes);")); - Assert.That(generatedText, Does.Contain("writer.WriteObjectValue((Widget)result, options);")); - Assert.That(generatedText, Does.Contain("public static string Invoke(byte[] value) => ToString(value);")); - Assert.That(generatedText, Does.Contain("widget.Value.Equals(widget.Value);")); - Assert.That(generatedText, Does.Not.Contain("Use(((Widget)result))")); - Assert.That(generatedText, Does.Not.Contain("if ((value is ICollection collection)")); - Assert.That(generatedText, Does.Not.Contain("Byte[]")); - Assert.That(generatedText, Does.Not.Contain("Char[]")); - Assert.That(generatedText, Does.Not.Contain("(Options)null")); - Assert.That(generatedText, Does.Not.Contain("(String)null")); - } - - [Test] - public async Task GetGeneratedFilesAsync_ReducesThisQualificationSafely() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -namespace TestNamespace -{ - public class Container - { - public string Name { get; } - - public void Invoke() - { - this.Create(this.Name); - } - - private void Create(string name) { } - } -} -"""); - - Assert.That(generatedText, Does.Contain("Create(Name);")); - Assert.That(generatedText, Does.Not.Contain("this.Create")); - Assert.That(generatedText, Does.Not.Contain("this.Name")); - } - - [Test] - public async Task GetGeneratedFilesAsync_PreservesThisQualificationWhenLocalNameConflicts() - { - var generatedText = await ProcessGeneratedCodeAsync( - """ -// -#nullable disable - -namespace TestNamespace -{ - public class Container - { - public string Name { get; } - - public void Invoke(string Name) - { - this.Create(this.Name); - } - - private void Create(string name) { } - } -} -"""); - - Assert.That(generatedText, Does.Contain("Create(this.Name);")); - } - [Test] public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj() { @@ -845,27 +336,6 @@ public class Placeholder {{ }} return dllPath; } - private async Task ProcessGeneratedCodeAsync(string content, params string[] additionalMetadataReferencePaths) - { - MockHelpers.LoadMockGenerator( - outputPath: _projectDir, - configuration: "{\"package-name\": \"TestNamespace\"}", - additionalMetadataReferences: additionalMetadataReferencePaths.Select(static path => MetadataReference.CreateFromFile(path))); - - GeneratedCodeWorkspace.Initialize(); - var workspace = await GeneratedCodeWorkspace.Create(false); - await workspace.AddGeneratedFile(new CodeFile(content, "TestFile.cs")); - - string? generatedText = null; - await foreach (var generatedFile in workspace.GetGeneratedFilesAsync()) - { - generatedText = generatedFile.Text; - } - - Assert.That(generatedText, Is.Not.Null); - return generatedText!; - } - private void CreateTestAssemblyAndProjectFile(string nugetCacheDir, string csProjectFileName) { var ns = csProjectFileName.StartsWith("TestNamespaceUnevaluatedFrameworkValue") From c11ae8250b9147843022ae6d3d88feaad658c951 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 12 Jun 2026 04:22:34 +0000 Subject: [PATCH 14/16] Add provider reference map shadow replacement --- .../src/CSharpGen.cs | 2 + .../PostProcessing/GeneratedCodeWorkspace.cs | 5 + .../src/PostProcessing/PostProcessor.cs | 110 +++- .../ProviderReferenceMapShadowAnalyzer.cs | 616 ++++++++++++++++++ .../ProviderReferenceMapShadowResult.cs | 14 + 5 files changed, 716 insertions(+), 31 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index 238b16b9766..fdd3ae2191f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -121,6 +121,8 @@ await customCodeWorkspace.GetCompilationAsync(), // Add all the generated files to the workspace await MeasureGenerationStepAsync("Generation.AddGeneratedFilesToWorkspace", () => Task.WhenAll(generateFilesTasks)); + MeasureGenerationStep("Generation.ProviderReferenceMapShadowAnalysis", () => generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders)); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index bf4b08b31b5..10d66430f97 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -86,6 +86,11 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapShadowAnalyzer.Analyze(providers, _project); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index d5db7e47298..57136754da1 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -133,18 +133,35 @@ public async Task InternalizeAsync(Project project) // first get all the declared symbols var definitions = await MeasureAsync("PostProcessor.Internalize.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, true)); - // build the reference map - var referenceMap = - await MeasureAsync( - "PostProcessor.Internalize.BuildPublicReferenceMapAsync", - () => new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache)); - // get the root symbols - var rootSymbols = await MeasureAsync("PostProcessor.Internalize.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); - // traverse all the root and recursively add all the things we met - var publicSymbols = Measure("PostProcessor.Internalize.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray()); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); + IEnumerable symbolsToInternalize; + if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult) + { + symbolsToInternalize = Measure("PostProcessor.Internalize.UseShadowCandidates", () => + GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.InternalizeCandidates).ToArray()); + } + else + { + // build the reference map + var referenceMap = + await MeasureAsync( + "PostProcessor.Internalize.BuildPublicReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DeclaredNodesCache)); + // get the root symbols + var rootSymbols = await MeasureAsync("PostProcessor.Internalize.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); + // traverse all the root and recursively add all the things we met + var publicSymbols = Measure("PostProcessor.Internalize.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray()); + + symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); + } + + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult) + { + ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( + "internalize", + symbolsToInternalize.Select(static symbol => symbol.GetFullyQualifiedName()), + shadowResult.InternalizeCandidates); + } var nodesToInternalize = Measure("PostProcessor.Internalize.CollectNodes", () => { @@ -255,25 +272,45 @@ public async Task RemoveAsync(Project project) // find all the declarations, including non-public declared var definitions = await MeasureAsync("PostProcessor.Remove.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, false)); - // build reference map - var referenceMap = - await MeasureAsync( - "PostProcessor.Remove.BuildAllReferenceMapAsync", - () => new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache)); - // get root symbols - var rootSymbols = await MeasureAsync("PostProcessor.Remove.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) - rootSymbols.Add(_modelFactorySymbol); - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = Measure("PostProcessor.Remove.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray().AsEnumerable()); - - referencedSymbols = Measure("PostProcessor.Remove.AddSampleSymbols", () => AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols)); - var referencedSet = Measure("PostProcessor.Remove.BuildReferencedSet", () => new HashSet(referencedSymbols, SymbolEqualityComparer.Default)); - - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult) + { + symbolsToRemove = Measure("PostProcessor.Remove.UseShadowCandidates", () => + GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.RemoveCandidates).ToArray()); + referencedSet = Measure("PostProcessor.Remove.BuildShadowReferencedSet", () => + new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default)); + } + else + { + // build reference map + var referenceMap = + await MeasureAsync( + "PostProcessor.Remove.BuildAllReferenceMapAsync", + () => new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache)); + // get root symbols + var rootSymbols = await MeasureAsync("PostProcessor.Remove.GetRootSymbolsAsync", () => GetRootSymbolsAsync(project, definitions)); + // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal + // helpers that are required by the model factory. + if (_modelFactorySymbol != null) + rootSymbols.Add(_modelFactorySymbol); + // traverse the map to determine the declarations that we are about to remove, starting from root nodes + var referencedSymbols = Measure("PostProcessor.Remove.VisitSymbolsFromRoot", () => VisitSymbolsFromRootAsync(rootSymbols, referenceMap).ToArray().AsEnumerable()); + + referencedSymbols = Measure("PostProcessor.Remove.AddSampleSymbols", () => AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols)); + referencedSet = Measure("PostProcessor.Remove.BuildReferencedSet", () => new HashSet(referencedSymbols, SymbolEqualityComparer.Default)); + + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } + + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult) + { + ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( + "remove", + symbolsToRemove.Select(static symbol => symbol.GetFullyQualifiedName()), + shadowResult.RemoveCandidates); + } var nodesToRemove = Measure("PostProcessor.Remove.CollectNodes", () => { @@ -358,6 +395,17 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) + { + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } + } + private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) { var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs new file mode 100644 index 00000000000..7d549afb5a1 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs @@ -0,0 +1,616 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.TypeSpec.Generator +{ + internal static class ProviderReferenceMapShadowAnalyzer + { + private const string EnableEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW"; + private const string UseShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_USE_SHADOW"; + private const string OutputDirectoryEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_DIR"; + + private static ProviderReferenceMapShadowResult? _latestResult; + + public static bool IsEnabled => string.Equals( + Environment.GetEnvironmentVariable(EnableEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + + public static ProviderReferenceMapShadowResult? LatestResult => _latestResult; + + public static bool UseShadowMap => string.Equals( + Environment.GetEnvironmentVariable(UseShadowEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + + public static void Analyze(IReadOnlyList providers, Project project) + { + if (!IsEnabled) + { + _latestResult = null; + return; + } + + var graph = BuildGraph(providers); + var customRoots = GetCustomCodeGeneratedTypeRoots(project, graph.Nodes); + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false); + internalizeRoots.UnionWith(customRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, graph.References); + var internalizeHelperRoots = GetHelperRootNames(providers, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeReachable = GetReachableTypes(internalizeRoots, graph.References); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: true); + var internalizeCandidates = internalizeDeclaredNodes.Except(internalizeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var removeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: true); + removeRoots.UnionWith(customRoots); + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + var removeHelperRoots = GetHelperRootNames(providers, graph.Nodes, removeReachableWithoutHelpers); + removeRoots.UnionWith(removeHelperRoots); + var removeReachable = GetReachableTypes(removeRoots, graph.References); + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: false); + var removeCandidates = removeDeclaredNodes.Except(removeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var helperRoots = internalizeHelperRoots.Concat(removeHelperRoots).ToHashSet(StringComparer.Ordinal); + + _latestResult = new ProviderReferenceMapShadowResult( + internalizeCandidates.ToHashSet(StringComparer.Ordinal), + removeCandidates.ToHashSet(StringComparer.Ordinal)); + + WriteReport(graph, customRoots, helperRoots, internalizeRoots, internalizeReachable, internalizeCandidates, removeRoots, removeReachable, removeCandidates); + } + + private static HashSet GetCustomCodeGeneratedTypeRoots(Project project, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return roots; + } + + foreach (var document in project.Documents) + { + if (GeneratedCodeWorkspace.IsGeneratedDocument(document) || GeneratedCodeWorkspace.IsGeneratedTestDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var model = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetDeclaredSymbol(declaration) as ITypeSymbol, generatedTypeNames); + } + + foreach (var typeSyntax in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetTypeInfo(typeSyntax).Type, generatedTypeNames); + } + + foreach (var objectCreation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(objectCreation).Symbol?.ContainingType, generatedTypeNames); + } + + foreach (var invocation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(invocation).Symbol?.ContainingType, generatedTypeNames); + } + } + + return roots; + } + + private static void AddSymbolRoot(HashSet roots, ITypeSymbol? symbol, HashSet generatedTypeNames) + { + if (symbol is not INamedTypeSymbol namedType) + { + return; + } + + AddMatchingName(roots, namedType.GetFullyQualifiedName(), generatedTypeNames); + foreach (var typeArgument in namedType.TypeArguments) + { + AddSymbolRoot(roots, typeArgument, generatedTypeNames); + } + } + + public static void WriteComparisonReport(string passName, IEnumerable roslynCandidates, IEnumerable providerCandidates) + { + if (!IsEnabled) + { + return; + } + + var roslynSet = roslynCandidates.ToHashSet(StringComparer.Ordinal); + var providerSet = providerCandidates.ToHashSet(StringComparer.Ordinal); + var missingFromProvider = roslynSet.Except(providerSet, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + var extraInProvider = providerSet.Except(roslynSet, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var directory = GetOutputDirectory(); + Directory.CreateDirectory(directory); + var path = Path.Combine(directory, $"provider-reference-map-shadow-comparison-{passName}-{DateTime.UtcNow:yyyyMMddHHmmssfff}.txt"); + var builder = new StringBuilder(); + builder.AppendLine($"Provider reference map shadow comparison: {passName}"); + builder.AppendLine($"Roslyn candidates: {roslynSet.Count}"); + builder.AppendLine($"Provider candidates: {providerSet.Count}"); + builder.AppendLine($"Missing from provider: {missingFromProvider.Length}"); + builder.AppendLine($"Extra in provider: {extraInProvider.Length}"); + builder.AppendLine(); + builder.AppendLine("Missing from provider:"); + foreach (var item in missingFromProvider) + { + builder.AppendLine($" {item}"); + } + + builder.AppendLine(); + builder.AppendLine("Extra in provider:"); + foreach (var item in extraInProvider) + { + builder.AppendLine($" {item}"); + } + + File.WriteAllText(path, builder.ToString()); + CodeModelGenerator.Instance.Emitter.Debug($"Provider reference map shadow comparison written to {path}"); + } + + private static ProviderReferenceGraph BuildGraph(IReadOnlyList providers) + { + var nodes = providers + .Select(static provider => GetProviderTypeName(provider.Type)) + .ToHashSet(StringComparer.Ordinal); + var references = nodes.ToDictionary(static name => name, _ => new HashSet(StringComparer.Ordinal), StringComparer.Ordinal); + + foreach (var provider in providers) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes); + AddTypeReference(references[current], provider.BaseType, nodes); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes); + + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes); + } + + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes); + } + + foreach (var property in provider.Properties) + { + AddTypeReference(references[current], property.Type, nodes); + AddTypeReference(references[current], property.ExplicitInterface, nodes); + AddAttributes(references[current], property.Attributes, nodes); + } + + foreach (var field in provider.Fields) + { + AddTypeReference(references[current], field.Type, nodes); + AddAttributes(references[current], field.Attributes, nodes); + } + + foreach (var constructor in provider.Constructors) + { + AddSignatureReferences(references[current], constructor.Signature, nodes); + } + + foreach (var method in provider.Methods) + { + AddSignatureReferences(references[current], method.Signature, nodes); + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static HashSet GetRootNames(IReadOnlyList providers, HashSet nodes, HashSet helperRoots, bool includeModelFactory) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (provider.Name.EndsWith("Client", StringComparison.Ordinal) || + IsKept(provider.Type, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + foreach (var root in generator.TypeFactory.UnionVariantTypesToKeep) + { + AddMatchingName(roots, root, nodes); + } + + foreach (var root in generator.AdditionalRootTypes) + { + AddMatchingName(roots, root, nodes); + } + + return roots; + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var excludedNames = generator.NonRootTypes; + return providers + .Where(provider => !IsModelFactoryProvider(provider)) + .Where(provider => !publicOnly || provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + .Select(provider => GetProviderTypeName(provider.Type)) + .Where(name => nodes.Contains(name)) + .Where(name => !excludedNames.Contains(name) && !excludedNames.Contains(GetSimpleName(name))) + .ToHashSet(StringComparer.Ordinal); + } + + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) => + roots.Contains(type.Name) || roots.Contains(GetProviderTypeName(type)) && nodes.Contains(GetProviderTypeName(type)); + + private static bool IsModelFactoryProvider(TypeProvider provider) => provider.GetType().Name == "ModelFactoryProvider"; + + private static HashSet GetHelperRootNames(IReadOnlyList providers, HashSet nodes, HashSet reachableTypes) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in providers) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + if (type.IsDictionary) + { + AddMatchingName(roots, "ChangeTrackingDictionary", nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + foreach (var node in nodes) + { + if (string.Equals(StripGenericArity(GetSimpleName(node)), name, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static void AddSignatureReferences(HashSet references, MethodSignatureBase signature, HashSet nodes) + { + AddTypeReference(references, signature.ReturnType, nodes); + AddAttributes(references, signature.Attributes, nodes); + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes); + AddAttributes(references, parameter.Attributes, nodes); + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes); + } + } + + private static void AddAttributes(HashSet references, IReadOnlyList attributes, HashSet nodes) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes); + } + } + + private static void AddTypeReference(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + } + + AddTypeReference(references, type.BaseType, nodes); + AddTypeReference(references, type.DeclaringType, nodes); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes); + } + } + + private static void WriteReport( + ProviderReferenceGraph graph, + HashSet customRoots, + HashSet helperRoots, + HashSet internalizeRoots, + HashSet internalizeReachable, + IReadOnlyList internalizeCandidates, + HashSet removeRoots, + HashSet removeReachable, + IReadOnlyList removeCandidates) + { + var directory = GetOutputDirectory(); + + Directory.CreateDirectory(directory); + var path = Path.Combine(directory, $"provider-reference-map-shadow-{DateTime.UtcNow:yyyyMMddHHmmssfff}.txt"); + var builder = new StringBuilder(); + builder.AppendLine("Provider reference map shadow report"); + builder.AppendLine($"Declared providers: {graph.Nodes.Count}"); + builder.AppendLine($"Internalize roots: {internalizeRoots.Count}"); + builder.AppendLine($"Internalize reachable: {internalizeReachable.Count}"); + builder.AppendLine($"Internalize candidates: {internalizeCandidates.Count}"); + builder.AppendLine($"Custom roots: {customRoots.Count}"); + builder.AppendLine($"Helper roots: {helperRoots.Count}"); + builder.AppendLine($"Remove roots: {removeRoots.Count}"); + builder.AppendLine($"Remove reachable: {removeReachable.Count}"); + builder.AppendLine($"Remove candidates: {removeCandidates.Count}"); + builder.AppendLine(); + builder.AppendLine("Custom roots:"); + foreach (var root in customRoots.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" {root}"); + } + + builder.AppendLine(); + builder.AppendLine("Helper roots:"); + foreach (var root in helperRoots.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" {root}"); + } + + builder.AppendLine(); + builder.AppendLine("Internalize roots:"); + foreach (var root in internalizeRoots.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" {root}"); + } + + builder.AppendLine(); + builder.AppendLine("Internalize candidates:"); + foreach (var candidate in internalizeCandidates) + { + builder.AppendLine($" {candidate}"); + } + + builder.AppendLine(); + builder.AppendLine("Remove roots:"); + foreach (var root in removeRoots.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" {root}"); + } + + builder.AppendLine(); + builder.AppendLine("Remove candidates:"); + foreach (var candidate in removeCandidates) + { + builder.AppendLine($" {candidate}"); + } + + builder.AppendLine(); + builder.AppendLine("References:"); + foreach (var (type, references) in graph.References.OrderBy(static item => item.Key, StringComparer.Ordinal)) + { + builder.AppendLine($" {type}"); + foreach (var reference in references.OrderBy(static name => name, StringComparer.Ordinal)) + { + builder.AppendLine($" -> {reference}"); + } + } + + File.WriteAllText(path, builder.ToString()); + CodeModelGenerator.Instance.Emitter.Debug($"Provider reference map shadow report written to {path}"); + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + + private static string GetOutputDirectory() + { + var directory = Environment.GetEnvironmentVariable(OutputDirectoryEnvironmentVariable); + return string.IsNullOrWhiteSpace(directory) + ? Path.Combine(Path.GetTempPath(), "typespec-provider-reference-map-shadow") + : Path.GetFullPath(directory); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs new file mode 100644 index 00000000000..14b68f7d187 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapShadowResult( + HashSet InternalizeCandidates, + HashSet RemoveCandidates) + { + public static ProviderReferenceMapShadowResult Empty { get; } = new([], []); + } +} From a5f8d403fd53a6c15447016d678ee1a14cd35b63 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 12 Jun 2026 08:37:46 +0000 Subject: [PATCH 15/16] Update provider reference map benchmarks --- .../perf/FullGenerationBenchmark.cs | 20 +++ .../src/PostProcessing/PostProcessor.cs | 12 +- .../ProviderReferenceMapShadowAnalyzer.cs | 115 +++++++++++++++++- .../ProviderReferenceMapShadowResult.cs | 3 +- 4 files changed, 142 insertions(+), 8 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs index 66553e08c38..35344ec5216 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/perf/FullGenerationBenchmark.cs @@ -14,9 +14,15 @@ public class FullGenerationBenchmark { private const string ProfileEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_STEPS"; private const string ProfileOutputDirectoryEnvironmentVariable = "POSTPROCESSING_BENCHMARK_PROFILE_DIR"; + private const string ShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW"; + private const string UseShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_USE_SHADOW"; + private const string ShadowReportEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_REPORT"; private bool _profileSteps; + [Params(false, true)] + public bool UseProviderReferenceMap { get; set; } + [GlobalSetup] public void GlobalSetup() { @@ -35,9 +41,13 @@ public async Task GenerateSampleTypeSpecProject() CSharpGen.GenerationProfile = generationProfile; var benchmarkDirectory = CreateBenchmarkInputDirectory(); + var previousShadow = Environment.GetEnvironmentVariable(ShadowEnvironmentVariable); + var previousUseShadow = Environment.GetEnvironmentVariable(UseShadowEnvironmentVariable); + var previousShadowReport = Environment.GetEnvironmentVariable(ShadowReportEnvironmentVariable); var stopwatch = Stopwatch.StartNew(); try { + SetProviderReferenceMapEnvironment(); CodeModelGenerator.Instance = new BenchmarkCodeModelGenerator(benchmarkDirectory); CodeModelGenerator.Instance.Configure(); @@ -72,10 +82,20 @@ public async Task GenerateSampleTypeSpecProject() CSharpGen.GenerationProfile = null; GeneratedCodeWorkspace.PostProcessingProfile = null; + Environment.SetEnvironmentVariable(ShadowEnvironmentVariable, previousShadow); + Environment.SetEnvironmentVariable(UseShadowEnvironmentVariable, previousUseShadow); + Environment.SetEnvironmentVariable(ShadowReportEnvironmentVariable, previousShadowReport); TryDeleteDirectory(benchmarkDirectory); } } + private void SetProviderReferenceMapEnvironment() + { + Environment.SetEnvironmentVariable(ShadowEnvironmentVariable, UseProviderReferenceMap ? "true" : null); + Environment.SetEnvironmentVariable(UseShadowEnvironmentVariable, UseProviderReferenceMap ? "true" : null); + Environment.SetEnvironmentVariable(ShadowReportEnvironmentVariable, null); + } + private static void WriteProfile(GeneratedCodeWorkspacePostProcessingProfile profile, string fileName, string header) { var profileDirectory = GetProfileOutputDirectory(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index 57136754da1..13fbf1af354 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -134,7 +134,9 @@ public async Task InternalizeAsync(Project project) // first get all the declared symbols var definitions = await MeasureAsync("PostProcessor.Internalize.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, true)); IEnumerable symbolsToInternalize; - if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult) + if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && + ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult && + useShadowResult.ProjectId == project.Id) { symbolsToInternalize = Measure("PostProcessor.Internalize.UseShadowCandidates", () => GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.InternalizeCandidates).ToArray()); @@ -155,7 +157,7 @@ await MeasureAsync( symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); } - if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult) + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult && shadowResult.ProjectId == project.Id) { ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( "internalize", @@ -274,7 +276,9 @@ public async Task RemoveAsync(Project project) var definitions = await MeasureAsync("PostProcessor.Remove.GetTypeSymbolsAsync", () => GetTypeSymbolsAsync(compilation, project, false)); IEnumerable symbolsToRemove; HashSet referencedSet; - if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult) + if (ProviderReferenceMapShadowAnalyzer.UseShadowMap && + ProviderReferenceMapShadowAnalyzer.LatestResult is { } useShadowResult && + useShadowResult.ProjectId == project.Id) { symbolsToRemove = Measure("PostProcessor.Remove.UseShadowCandidates", () => GetSymbolsByName(definitions.DeclaredSymbols, useShadowResult.RemoveCandidates).ToArray()); @@ -304,7 +308,7 @@ await MeasureAsync( symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); } - if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult) + if (ProviderReferenceMapShadowAnalyzer.LatestResult is { } shadowResult && shadowResult.ProjectId == project.Id) { ProviderReferenceMapShadowAnalyzer.WriteComparisonReport( "remove", diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs index 7d549afb5a1..1196f44ed74 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs @@ -11,6 +11,7 @@ using Microsoft.TypeSpec.Generator.Statements; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.FindSymbols; namespace Microsoft.TypeSpec.Generator { @@ -18,6 +19,7 @@ internal static class ProviderReferenceMapShadowAnalyzer { private const string EnableEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW"; private const string UseShadowEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_USE_SHADOW"; + private const string ReportEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_REPORT"; private const string OutputDirectoryEnvironmentVariable = "TYPESPEC_PROVIDER_REFERENCE_MAP_SHADOW_DIR"; private static ProviderReferenceMapShadowResult? _latestResult; @@ -34,6 +36,11 @@ internal static class ProviderReferenceMapShadowAnalyzer "true", StringComparison.OrdinalIgnoreCase); + private static bool ShouldWriteReports => string.Equals( + Environment.GetEnvironmentVariable(ReportEnvironmentVariable), + "true", + StringComparison.OrdinalIgnoreCase); + public static void Analyze(IReadOnlyList providers, Project project) { if (!IsEnabled) @@ -53,6 +60,8 @@ public static void Analyze(IReadOnlyList providers, Project projec var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: true); var internalizeCandidates = internalizeDeclaredNodes.Except(internalizeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + AddGeneratedBodyReferences(project, providers, graph); + var removeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: true); removeRoots.UnionWith(customRoots); var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); @@ -65,10 +74,14 @@ public static void Analyze(IReadOnlyList providers, Project projec var helperRoots = internalizeHelperRoots.Concat(removeHelperRoots).ToHashSet(StringComparer.Ordinal); _latestResult = new ProviderReferenceMapShadowResult( + project.Id, internalizeCandidates.ToHashSet(StringComparer.Ordinal), removeCandidates.ToHashSet(StringComparer.Ordinal)); - WriteReport(graph, customRoots, helperRoots, internalizeRoots, internalizeReachable, internalizeCandidates, removeRoots, removeReachable, removeCandidates); + if (ShouldWriteReports) + { + WriteReport(graph, customRoots, helperRoots, internalizeRoots, internalizeReachable, internalizeCandidates, removeRoots, removeReachable, removeCandidates); + } } private static HashSet GetCustomCodeGeneratedTypeRoots(Project project, HashSet generatedTypeNames) @@ -134,7 +147,7 @@ private static void AddSymbolRoot(HashSet roots, ITypeSymbol? symbol, Ha public static void WriteComparisonReport(string passName, IEnumerable roslynCandidates, IEnumerable providerCandidates) { - if (!IsEnabled) + if (!IsEnabled || !ShouldWriteReports) { return; } @@ -232,6 +245,99 @@ private static ProviderReferenceGraph BuildGraph(IReadOnlyList pro return new ProviderReferenceGraph(nodes, references); } + private static void AddGeneratedBodyReferences(Project project, IReadOnlyList providers, ProviderReferenceGraph graph) + { + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return; + } + + foreach (var provider in providers) + { + if (!IsGeneratedBodyReferenceCandidate(provider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + + var symbol = compilation.GetTypeByMetadataName(providerName); + if (symbol == null) + { + continue; + } + + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, symbol); + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + foreach (var method in symbol.GetMembers().OfType()) + { + if (method.IsExtensionMethod) + { + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, method); + } + } + } + } + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith("/Internal/ClientUriBuilder.cs", StringComparison.Ordinal) || + relativePath.Contains("/CollectionResults/", StringComparison.Ordinal); + } + + private static void AddGeneratedReferencesToHelper(Project project, Compilation compilation, ProviderReferenceGraph graph, string helperName, ISymbol symbol) + { + foreach (var reference in SymbolFinder.FindReferencesAsync(symbol, project.Solution).GetAwaiter().GetResult()) + { + foreach (var location in reference.Locations) + { + var document = location.Document; + if (!GeneratedCodeWorkspace.IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var node = root.FindNode(location.Location.SourceSpan); + var owner = node.AncestorsAndSelf().OfType().FirstOrDefault(); + if (owner == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(owner.SyntaxTree); + if (semanticModel.GetDeclaredSymbol(owner) is not INamedTypeSymbol ownerSymbol) + { + continue; + } + + var ownerName = ownerSymbol.GetFullyQualifiedName(); + if (graph.Nodes.Contains(ownerName)) + { + graph.References[ownerName].Add(helperName); + } + } + } + } + private static HashSet GetRootNames(IReadOnlyList providers, HashSet nodes, HashSet helperRoots, bool includeModelFactory) { var generator = CodeModelGenerator.Instance; @@ -241,7 +347,7 @@ private static HashSet GetRootNames(IReadOnlyList provider foreach (var provider in providers) { var name = GetProviderTypeName(provider.Type); - if (provider.Name.EndsWith("Client", StringComparison.Ordinal) || + if (IsClientProviderRoot(provider) || IsKept(provider.Type, generator.AdditionalRootTypes, nodes) || includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || includeModelFactory && helperRoots.Contains(name)) @@ -279,6 +385,9 @@ private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList roots, HashSet nodes) => roots.Contains(type.Name) || roots.Contains(GetProviderTypeName(type)) && nodes.Contains(GetProviderTypeName(type)); + private static bool IsClientProviderRoot(TypeProvider provider) => + provider.RelativeFilePath.EndsWith("Client.cs", StringComparison.Ordinal); + private static bool IsModelFactoryProvider(TypeProvider provider) => provider.GetType().Name == "ModelFactoryProvider"; private static HashSet GetHelperRootNames(IReadOnlyList providers, HashSet nodes, HashSet reachableTypes) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs index 14b68f7d187..48a4038183d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowResult.cs @@ -2,13 +2,14 @@ // Licensed under the MIT License. using System.Collections.Generic; +using Microsoft.CodeAnalysis; namespace Microsoft.TypeSpec.Generator { internal sealed record ProviderReferenceMapShadowResult( + ProjectId ProjectId, HashSet InternalizeCandidates, HashSet RemoveCandidates) { - public static ProviderReferenceMapShadowResult Empty { get; } = new([], []); } } From df00d674eba31fcffb24f1055cbd1da05fb205eb Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 12 Jun 2026 09:27:07 +0000 Subject: [PATCH 16/16] Track serialization providers in benchmark reference map --- .../ProviderReferenceMapShadowAnalyzer.cs | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs index 1196f44ed74..ffa8ac5b8f5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapShadowAnalyzer.cs @@ -186,12 +186,13 @@ public static void WriteComparisonReport(string passName, IEnumerable ro private static ProviderReferenceGraph BuildGraph(IReadOnlyList providers) { - var nodes = providers + var generatedProviders = GetGeneratedProviders(providers); + var nodes = generatedProviders .Select(static provider => GetProviderTypeName(provider.Type)) .ToHashSet(StringComparer.Ordinal); var references = nodes.ToDictionary(static name => name, _ => new HashSet(StringComparer.Ordinal), StringComparer.Ordinal); - foreach (var provider in providers) + foreach (var provider in generatedProviders) { var current = GetProviderTypeName(provider.Type); AddTypeReference(references[current], provider.Type, nodes); @@ -245,6 +246,18 @@ private static ProviderReferenceGraph BuildGraph(IReadOnlyList pro return new ProviderReferenceGraph(nodes, references); } + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + generatedProviders.Add(provider); + generatedProviders.AddRange(provider.SerializationProviders); + } + + return generatedProviders; + } + private static void AddGeneratedBodyReferences(Project project, IReadOnlyList providers, ProviderReferenceGraph graph) { var compilation = project.GetCompilationAsync().GetAwaiter().GetResult();