Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,17 @@ private static bool IsMatch(TypeProvider enclosingType, MethodSignatureBase sign
for (int i = 0; i < parameterTypes.Length; i++)
{
var parameterType = ((ITypeSymbol)parameterTypes[i]!).GetCSharpType();
var signatureParamType = signature.Parameters[i].Type;

// If the parameter type is a generic type alias, resolve to the constraint type for matching.
if (signature is MethodSignature methodSig)
{
signatureParamType = ResolveGenericConstraintType(signatureParamType, methodSig);
}

// we ignore nullability for reference types as these are generated the same regardless of nullability
if (!IsNameMatch(parameterType, signature.Parameters[i].Type) ||
(parameterType.IsValueType && parameterType.IsNullable != signature.Parameters[i].Type.IsNullable))
if (!IsNameMatch(parameterType, signatureParamType) ||
(parameterType.IsValueType && parameterType.IsNullable != signatureParamType.IsNullable))
{
return false;
}
Expand All @@ -734,6 +742,40 @@ private static bool IsMatch(TypeProvider enclosingType, MethodSignatureBase sign
return true;
}

private static CSharpType ResolveGenericConstraintType(CSharpType paramType, MethodSignature methodSignature)
{
if (methodSignature.GenericArguments is null || methodSignature.GenericParameterConstraints is null)
{
return paramType;
}

foreach (var genericArg in methodSignature.GenericArguments)
{
if (genericArg.Name != paramType.Name)
{
continue;
}

foreach (var whereExpr in methodSignature.GenericParameterConstraints)
{
if (whereExpr.Type.Name != paramType.Name)
{
continue;
}

foreach (var constraint in whereExpr.Constraints)
{
if (constraint is TypeReferenceExpression { Type: { } constraintType })
{
return constraintType;
}
}
}
}

return paramType;
}

private static bool IsNameMatch(CSharpType typeFromCustomization, CSharpType generatedType)
{
// The namespace may not be available for generated types referenced from customization as they
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Snippets;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;

namespace Microsoft.TypeSpec.Generator.Tests.Providers.ModelProviders
{
Expand Down Expand Up @@ -361,6 +363,43 @@ protected override TypeProvider[] BuildTypeProviders()
}
}

[Test]
public async Task CanRemoveMethodWithGenericTypeConstraintParameter()
{
var client = new ClientTypeProvider();
var outputLibrary = new ClientOutputLibrary(client);
var mockGenerator = await MockHelpers.LoadMockGeneratorAsync(
createOutputLibrary: () => outputLibrary,
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

// Create a generic type parameter T
var tType = typeof(GenericTemplate<>).GetGenericArguments()[0];

var methods = new[]
{
new MethodProvider(new MethodSignature(
"Method1",
$"",
MethodSignatureModifiers.Public,
null,
$"",
[
new ParameterProvider("param1", $"", tType)
],
GenericArguments: [tType],
GenericParameterConstraints: [Where.Implements(tType, typeof(IDisposable))]),
Snippet.ThrowExpression(Snippet.Null), client),
};
client.MethodProviders = methods;

var csharpGen = new CSharpGen();
await csharpGen.ExecuteAsync();

Assert.AreEqual(0, mockGenerator.Object.OutputLibrary.TypeProviders.Single(t => t.Name == "MockInputClient").Methods.Count);
}

private class GenericTemplate<T> { }

private class ClientTypeProvider : TypeProvider
{
public MethodProvider[] MethodProviders { get; set; } = [];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System;
using Microsoft.TypeSpec.Generator.Customizations;

namespace Sample
{
[CodeGenSuppress("Method1", typeof(IDisposable))]
public partial class MockInputClient
{
}
}