Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ internal class PostProcessor
private readonly HashSet<string> _typesToKeep;
private INamedTypeSymbol? _modelFactorySymbol;

// CS8019: Unnecessary using directive.
private const string UnnecessaryUsingDirectiveDiagnosticId = "CS8019";

// Diagnostics that indicate a type, namespace, or name could not be resolved. When any of these are
// present in a document the compiler cannot reliably determine whether a using directive is used, so
// the CS8019 ("unnecessary using") diagnostic may be a false positive and must not be trusted.
private static readonly HashSet<string> UnresolvedReferenceDiagnosticIds =
[
"CS0103", // The name does not exist in the current context.
"CS0234", // The type or namespace name does not exist in the namespace (missing assembly reference?).
"CS0246", // The type or namespace name could not be found (missing using or assembly reference?).
"CS0305", // Using the generic type requires type arguments.
"CS0308", // The non-generic type cannot be used with type arguments.
];

public PostProcessor(
HashSet<string> typesToKeep,
string? modelFactoryFullName = null,
Expand Down Expand Up @@ -156,20 +171,39 @@ public async Task<Project> InternalizeAsync(Project project)
project = MarkInternal(project, model, documentId);
}

if (nodesToInternalize.Count > 0)
{
CodeModelGenerator.Instance.Emitter.Info(
$"Internalized {nodesToInternalize.Count} unreferenced public type(s).");
}

var modelNamesToRemove =
nodesToInternalize.Keys.Select(item => item.Identifier.Text);
project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet());
DocumentId? modelFactoryDocumentId;
(project, modelFactoryDocumentId) = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet());

if (nodesToInternalize.Count > 0)
{
var documentsToClean = nodesToInternalize.Values.ToHashSet();
// Removing methods from the model factory can leave a using directive (for a model in a
// different namespace) unused, so include the model factory document in the cleanup pass.
if (modelFactoryDocumentId != null)
{
documentsToClean.Add(modelFactoryDocumentId);
}
project = await RemoveUnusedUsingsAsync(project, documentsToClean);
}

return project;
}

private async Task<Project> RemoveMethodsFromModelFactoryAsync(Project project,
private async Task<(Project Project, DocumentId? ModelFactoryDocumentId)> RemoveMethodsFromModelFactoryAsync(Project project,
TypeSymbols definitions,
HashSet<string> namesToRemove)
{
var modelFactorySymbol = definitions.ModelFactorySymbol;
if (modelFactorySymbol == null)
return project;
return (project, null);

var nodesToRemove = new List<SyntaxNode>();

Expand Down Expand Up @@ -203,7 +237,7 @@ private async Task<Project> RemoveMethodsFromModelFactoryAsync(Project project,

// maybe this is possible, for instance, we could be adding the customization all entries previously inside the generated model factory so that the generated model factory is empty and removed.
if (modelFactoryGeneratedDocument == null)
return project;
return (project, null);

var root = await modelFactoryGeneratedDocument.GetSyntaxRootAsync();
Debug.Assert(root is not null);
Expand All @@ -214,10 +248,10 @@ private async Task<Project> RemoveMethodsFromModelFactoryAsync(Project project,
var methods = root.DescendantNodes().OfType<MethodDeclarationSyntax>();
if (!methods.Any())
{
return project.RemoveDocument(modelFactoryGeneratedDocument.Id);
return (project.RemoveDocument(modelFactoryGeneratedDocument.Id), null);
}

return modelFactoryGeneratedDocument.Project;
return (modelFactoryGeneratedDocument.Project, modelFactoryGeneratedDocument.Id);
}

/// <summary>
Expand Down Expand Up @@ -336,6 +370,9 @@ private static IEnumerable<T> GetReferencedTypes<T>(T definition,

private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId)
{
CodeModelGenerator.Instance.Emitter.Debug(
$"Internalizing unreferenced public type '{declarationNode.Identifier.Text}'.");

var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword);
var tree = declarationNode.SyntaxTree;
var document = project.GetDocument(documentId)!;
Expand Down Expand Up @@ -432,6 +469,76 @@ private async Task<Project> RemoveInvalidRefs(Project project)
return solution.GetProject(project.Id)!;
}

private async Task<Project> RemoveUnusedUsingsAsync(Project project, IEnumerable<DocumentId> documentIds)
{
var solution = project.Solution;
foreach (var documentId in documentIds)
{
solution = await RemoveUnusedUsings(solution, documentId);
}

return solution.GetProject(project.Id) ?? project;
}

private async Task<Solution> RemoveUnusedUsings(Solution solution, DocumentId documentId)
{
var document = solution.GetDocument(documentId);
if (document == null)
{
return solution;
}

document = await Simplifier.ReduceAsync(document);

var root = await document.GetSyntaxRootAsync();
var model = await document.GetSemanticModelAsync();

if (root is not CompilationUnitSyntax cu || model == null)
Comment thread
jorgerangel-msft marked this conversation as resolved.
{
return solution;
}

var diagnostics = model.GetDiagnostics();

if (diagnostics.Any(d => UnresolvedReferenceDiagnosticIds.Contains(d.Id)))
{
return solution;
}

var unusedUsings = diagnostics
.Where(d => d.Id == UnnecessaryUsingDirectiveDiagnosticId)
.Select(d => cu.FindNode(d.Location.SourceSpan).FirstAncestorOrSelf<UsingDirectiveSyntax>())
.OfType<UsingDirectiveSyntax>()
.ToList();

if (unusedUsings.Count == 0)
{
return solution;
}

// Preserve any leading trivia on the first using directive (such as the file header and the
// #nullable directive) when that directive is removed, by carrying it over to the node that
// follows the removed directives.
var firstUsing = cu.Usings.FirstOrDefault();
var leadingTrivia = firstUsing is not null && unusedUsings.Contains(firstUsing)
? firstUsing.GetLeadingTrivia()
: default;

var updatedRoot = cu.RemoveNodes(unusedUsings, SyntaxRemoveOptions.KeepNoTrivia);
if (updatedRoot == null)
{
return solution;
}

if (leadingTrivia.Count > 0)
{
var firstToken = updatedRoot.GetFirstToken();
updatedRoot = updatedRoot.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(leadingTrivia.AddRange(firstToken.LeadingTrivia)));
}

return solution.WithDocumentSyntaxRoot(documentId, updatedRoot);
}

private async Task<Solution> RemoveInvalidUsings(Solution solution, DocumentId documentId)
{
var document = solution.GetDocument(documentId)!;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;
Expand Down Expand Up @@ -289,11 +290,156 @@ public async Task DoesNotRemoveValidAttributes()
Assert.AreEqual(Helpers.GetExpectedFromFile().TrimEnd(), output, "The output should match the expected content.");
}

[Test]
public async Task RemovesUnusedUsingFromModelFactoryWhenInternalizing()
{
MockHelpers.LoadMockGenerator();
var workspace = new AdhocWorkspace();
var projectInfo = ProjectInfo.Create(
ProjectId.CreateNewId(),
VersionStamp.Create(),
name: "TestProj",
assemblyName: "TestProj",
language: LanguageNames.CSharp)
.WithMetadataReferences(new[]
{
MetadataReference.CreateFromFile(typeof(object).Assembly.Location)
});

var project = workspace.AddProject(projectInfo);
var folder = Helpers.GetAssetFileOrDirectoryPath(false);
const string rootFileName = "ModelFactoryRoot.cs";
string[] modelFileNames =
[
"KeptModel.cs",
"OtherNamespaceModel.cs"
];
foreach (var fileName in modelFileNames)
{
project = project.AddDocument(
fileName,
File.ReadAllText(Path.Join(folder, fileName))).Project;
}
// The model factory lives in a generated document so the post-processor can rewrite it.
const string modelFactoryFileName = "SampleModelFactory.cs";
project = project.AddDocument(
modelFactoryFileName,
File.ReadAllText(Path.Join(folder, modelFactoryFileName)),
folders: ["Generated"]).Project;
project = project.AddDocument(
rootFileName,
File.ReadAllText(Path.Join(folder, rootFileName))).Project;
var postProcessor = new TestPostProcessor(rootFileName, modelFactoryFullName: "Sample.SampleModelFactory");

var resultProject = await postProcessor.InternalizeAsync(project);

// The model in the other namespace is unreferenced and is internalized.
var otherNamespaceModel = await GetSingleClassAsync(resultProject, "OtherNamespaceModel.cs", "OtherNamespaceModel");
Assert.IsTrue(otherNamespaceModel.Modifiers.Any(m => m.IsKind(SyntaxKind.InternalKeyword)));

// The referenced model stays public and keeps its model factory method.
var keptModel = await GetSingleClassAsync(resultProject, "KeptModel.cs", "KeptModel");
Assert.IsTrue(keptModel.Modifiers.Any(m => m.IsKind(SyntaxKind.PublicKeyword)));

var modelFactory = await GetSingleClassAsync(resultProject, modelFactoryFileName, "SampleModelFactory");
var methodNames = modelFactory.Members
.OfType<MethodDeclarationSyntax>()
.Select(m => m.Identifier.Text)
.ToList();
Assert.IsTrue(methodNames.Contains("KeptModel"), "The model factory method for the referenced model should be preserved.");
Assert.IsFalse(methodNames.Contains("OtherNamespaceModel"), "The model factory method for the internalized model should be removed.");

// Removing the model factory method for the internalized model leaves the using for its namespace unused, so it is removed.
Assert.IsFalse(
await HasUsingAsync(resultProject, modelFactoryFileName, "Sample.Models"),
"Unused using for the internalized model's namespace should be removed from the model factory.");

Assert.AreEqual(
Helpers.GetExpectedFromFile().TrimEnd(),
(await GetDocumentTextAsync(resultProject, modelFactoryFileName)).TrimEnd(),
"The generated model factory output should match the expected content.");
}

[Test]
public async Task KeepsUsedUsingWhenTypeReferencesAreUnresolved()
{
MockHelpers.LoadMockGenerator();
var workspace = new AdhocWorkspace();
// Intentionally omit the System.ClientModel reference so that the BinaryContent/ClientResult
// type references in the serialization document cannot be resolved. This mirrors the post-processing
// compilation in real generation, where not every external assembly is referenced.
var projectInfo = ProjectInfo.Create(
ProjectId.CreateNewId(),
VersionStamp.Create(),
name: "TestProj",
assemblyName: "TestProj",
language: LanguageNames.CSharp)
.WithMetadataReferences(new[]
{
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
MetadataReference.CreateFromFile(typeof(System.BinaryData).Assembly.Location)
});

var project = workspace.AddProject(projectInfo);
var folder = Helpers.GetAssetFileOrDirectoryPath(false);
const string rootFileName = "Root.cs";
const string serializationFileName = "UnreferencedModel.Serialization.cs";
foreach (var fileName in new[] { "UnreferencedModel.cs", serializationFileName })
{
project = project.AddDocument(fileName, File.ReadAllText(Path.Join(folder, fileName)), folders: ["Generated"]).Project;
}
project = project.AddDocument(rootFileName, File.ReadAllText(Path.Join(folder, rootFileName))).Project;
var postProcessor = new TestPostProcessor(rootFileName);

var resultProject = await postProcessor.InternalizeAsync(project);

// The unreferenced model is internalized.
var model = await GetSingleClassAsync(resultProject, "UnreferencedModel.cs", "UnreferencedModel");
Assert.IsTrue(model.Modifiers.Any(m => m.IsKind(SyntaxKind.InternalKeyword)));

// The using is still referenced by the conversion operators, but the type references cannot be
// resolved in this compilation. The using must be preserved rather than removed by the CS8019 pass.
Assert.IsTrue(
await HasUsingAsync(resultProject, serializationFileName, "System.ClientModel"),
"A used using directive must not be removed when its type references cannot be resolved.");

Assert.AreEqual(
Helpers.GetExpectedFromFile().TrimEnd(),
(await GetDocumentTextAsync(resultProject, serializationFileName)).TrimEnd(),
"The serialization document should be left untouched when references are unresolved.");
}

private static async Task<string> GetDocumentTextAsync(Project project, string fileName)
{
var doc = project.Documents.Single(d => d.Name == fileName);
var root = await doc.GetSyntaxRootAsync();
return root!.ToFullString();
}

private static async Task<bool> HasUsingAsync(Project project, string fileName, string usingName)
{
var doc = project.Documents.Single(d => d.Name == fileName);
var root = await doc.GetSyntaxRootAsync();
return ((CompilationUnitSyntax)root!)
.Usings
.Any(u => u.Name?.ToString() == usingName);
}

private static async Task<ClassDeclarationSyntax> GetSingleClassAsync(Project project, string fileName, string className)
{
var doc = project.Documents.Single(d => d.Name == fileName);
var root = await doc.GetSyntaxRootAsync();
return ((CompilationUnitSyntax)root!)
.DescendantNodes()
.OfType<ClassDeclarationSyntax>()
.Single(t => t.Identifier.Text == className);
}

private class TestPostProcessor : PostProcessor
{
private readonly string _rootFile;

public TestPostProcessor(string rootFile, IEnumerable<string>? nonRootTypes = null) : base([], additionalNonRootTypeNames: nonRootTypes)
public TestPostProcessor(string rootFile, IEnumerable<string>? nonRootTypes = null, string? modelFactoryFullName = null) : base([], modelFactoryFullName: modelFactoryFullName, additionalNonRootTypeNames: nonRootTypes)
{
_rootFile = rootFile;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System.ClientModel;

namespace Sample.Models
{
internal partial class UnreferencedModel
{
public static implicit operator BinaryContent(UnreferencedModel model)
{
return BinaryContent.Create(model);
}

public static explicit operator UnreferencedModel(ClientResult result)
{
PipelineResponse response = result.GetRawResponse();
return new UnreferencedModel();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Sample
{
public class Root
{
public int Value { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System.ClientModel;

namespace Sample.Models
{
public partial class UnreferencedModel
{
public static implicit operator BinaryContent(UnreferencedModel model)
{
return BinaryContent.Create(model);
}

public static explicit operator UnreferencedModel(ClientResult result)
{
PipelineResponse response = result.GetRawResponse();
return new UnreferencedModel();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Sample.Models
{
public partial class UnreferencedModel
{
public int X { get; set; }
}
}
Loading
Loading