diff --git a/src/StackExchange.Redis/ResultProcessor.cs b/src/StackExchange.Redis/ResultProcessor.cs index fc5c3d5b4..f7f6047b1 100644 --- a/src/StackExchange.Redis/ResultProcessor.cs +++ b/src/StackExchange.Redis/ResultProcessor.cs @@ -2894,10 +2894,18 @@ public override bool SetResult(PhysicalConnection connection, Message message, i if (connection.Protocol is null) { - // if we didn't get a valid response from HELLO, then we have to assume RESP2 at some point + // If we didn't get a valid response from HELLO, then we have to assume RESP2 at some point. + // We need the protocol assigned before OnFullyEstablished so that the + // protocol is reliably known *before* we do next-steps. connection.SetProtocol(RedisProtocol.Resp2); } + if (final & establishConnection) + { + // This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection. + connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}"); + } + return final; } @@ -2939,11 +2947,6 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes } if (happy) { - if (establishConnection) - { - // This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection. - connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}"); - } SetResult(message, happy); return true; } diff --git a/src/StackExchange.Redis/ServerEndPoint.cs b/src/StackExchange.Redis/ServerEndPoint.cs index abe8d8afb..9fdf0787f 100644 --- a/src/StackExchange.Redis/ServerEndPoint.cs +++ b/src/StackExchange.Redis/ServerEndPoint.cs @@ -695,7 +695,8 @@ internal void OnFullyEstablished(PhysicalConnection connection, string source) // Clear the unselectable flag ASAP since we are open for business ClearUnselectable(UnselectableFlags.DidNotRespond); - bool isResp3 = KnowOrAssumeResp3(); + // is *this specific* connection using RESP3? (without reference to config preferences) + bool isResp3 = connection?.Protocol is >= RedisProtocol.Resp3; if (bridge == subscription || isResp3) { // Note: this MUST be fire and forget, because we might be in the middle of a Sync processing @@ -703,6 +704,11 @@ internal void OnFullyEstablished(PhysicalConnection connection, string source) // Since we're issuing commands inside a SetResult path in a message, we'd create a deadlock by waiting. Multiplexer.EnsureSubscriptions(CommandFlags.FireAndForget); } + else if (SupportsSubscriptions && Multiplexer.RawConfig.Protocol > RedisProtocol.Resp2) + { + // interactive, and we wanted RESP3+, but we didn't get it; spin up pub/sub + Activate(ConnectionType.Subscription, null); + } if (IsConnected && (IsSubscriberConnected || !SupportsSubscriptions || isResp3)) { // Only connect on the second leg - we can accomplish this by checking both diff --git a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs index 6f80215dd..f4344047f 100644 --- a/tests/StackExchange.Redis.Tests/InProcessTestServer.cs +++ b/tests/StackExchange.Redis.Tests/InProcessTestServer.cs @@ -17,7 +17,8 @@ namespace StackExchange.Redis.Tests; public class InProcessTestServer : MemoryCacheRedisServer { private readonly ITestOutputHelper? _log; - public InProcessTestServer(ITestOutputHelper? log = null) + public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null) + : base(endpoint) { RedisVersion = RedisFeatures.v6_0_0; // for client to expect RESP3 _log = log; diff --git a/tests/StackExchange.Redis.Tests/Resp3HandshakeTests.cs b/tests/StackExchange.Redis.Tests/Resp3HandshakeTests.cs new file mode 100644 index 000000000..9ae7d556d --- /dev/null +++ b/tests/StackExchange.Redis.Tests/Resp3HandshakeTests.cs @@ -0,0 +1,118 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net; +using System.Threading.Tasks; +using StackExchange.Redis.Server; +using Xunit; + +namespace StackExchange.Redis.Tests; + +public class Resp3HandshakeTests(ITestOutputHelper log) +{ + public enum ServerResponse + { + Resp3, // up-level server style + Resp2, // DMC hybrid style, i.e. we know about it, but: "no, you'll take RESP2" + UnknownCommand, // down-level server style + } + + [Flags] + public enum HandshakeFlags + { + None = 0, + Authenticated = 1 << 0, + TieBreaker = 1 << 1, + ConfigChannel = 1 << 2, + UsePubSub = 1 << 3, + UseDatabase = 1 << 4, + } + + private static readonly int HandshakeFlagsCount = Enum.GetValues(typeof(HandshakeFlags)).Length - 1; + public static IEnumerable GetHandshakeParameters() + { + // all client protocols, all server-response modes; all flag permutations + var clients = (RedisProtocol[])Enum.GetValues(typeof(RedisProtocol)); + var servers = (ServerResponse[])Enum.GetValues(typeof(ServerResponse)); + foreach (var client in clients) + { + foreach (var server in servers) + { + if (client is RedisProtocol.Resp2 & server is not ServerResponse.Resp2) + { + // we don't issue HELLO for this, nothing to test + } + else + { + int count = 1 << HandshakeFlagsCount; + for (int i = 0; i < count; i++) + { + yield return [client, server, (HandshakeFlags)i]; + } + } + } + } + } + + [Theory] + [MemberData(nameof(GetHandshakeParameters))] + public async Task Handshake(RedisProtocol client, ServerResponse server, HandshakeFlags flags) + { + using var serverObj = new HandshakeServer(server, log); + serverObj.Password = (flags & HandshakeFlags.Authenticated) == 0 ? null : "mypassword"; + var config = serverObj.GetClientConfig(); + config.Protocol = client; + config.TieBreaker = (flags & HandshakeFlags.TieBreaker) == 0 ? "" : "tiebreaker_key"; + config.ConfigurationChannel = (flags & HandshakeFlags.ConfigChannel) == 0 ? "" : "broadcast_channel"; + + using var clientObj = await ConnectionMultiplexer.ConnectAsync(config); + + var sub = clientObj.GetSubscriber(); + var db = clientObj.GetDatabase(); + ConcurrentBag received = []; + RedisChannel channel = RedisChannel.Literal("mychannel"); + RedisKey key = "mykey"; + bool useDatabase = (flags & HandshakeFlags.UseDatabase) != 0; + bool usePubSub = (flags & HandshakeFlags.UsePubSub) != 0; + + if (usePubSub) + { + await sub.SubscribeAsync(channel, (x, y) => received.Add(y!)); + } + if (useDatabase) + { + await db.StringSetAsync(key, "myvalue"); + } + if (usePubSub) + { + await sub.PublishAsync(channel, "msg payload"); + for (int i = 0; i < 5 && received.IsEmpty; i++) + { + await Task.Delay(10, TestContext.Current.CancellationToken); + await sub.PingAsync(); + } + Assert.Equal("msg payload", Assert.Single(received)); + } + + if (useDatabase) + { + Assert.Equal("myvalue", await db.StringGetAsync(key)); + } + } + + private static readonly EndPoint EP = new DnsEndPoint("home", 8000); + private sealed class HandshakeServer(ServerResponse response, ITestOutputHelper log) + : InProcessTestServer(log, EP) + { + protected override RedisProtocol MaxProtocol => response switch + { + ServerResponse.Resp3 => RedisProtocol.Resp3, + _ => RedisProtocol.Resp2, + }; + + protected override TypedRedisValue Hello(RedisClient client, in RedisRequest request) + => response is ServerResponse.UnknownCommand + ? request.CommandNotFound() + : base.Hello(client, in request); + } +} diff --git a/toys/StackExchange.Redis.Server/RedisClient.cs b/toys/StackExchange.Redis.Server/RedisClient.cs index 1f09b3916..c1b097207 100644 --- a/toys/StackExchange.Redis.Server/RedisClient.cs +++ b/toys/StackExchange.Redis.Server/RedisClient.cs @@ -116,7 +116,6 @@ internal bool ShouldSkipResponse() public int Id { get; internal set; } public bool IsAuthenticated { get; internal set; } public RedisProtocol Protocol { get; internal set; } = RedisProtocol.Resp2; - public long ProtocolVersion => Protocol is RedisProtocol.Resp2 ? 2 : 3; public void Dispose() { diff --git a/toys/StackExchange.Redis.Server/RedisServer.cs b/toys/StackExchange.Redis.Server/RedisServer.cs index aa26e34a6..d16c6fb14 100644 --- a/toys/StackExchange.Redis.Server/RedisServer.cs +++ b/toys/StackExchange.Redis.Server/RedisServer.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Net; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using RESPite; @@ -158,7 +159,7 @@ protected override void AppendStats(StringBuilder sb) public override TypedRedisValue Execute(RedisClient client, in RedisRequest request) { var pw = Password; - if (pw.Length != 0 & !client.IsAuthenticated) + if (!string.IsNullOrEmpty(pw) & !client.IsAuthenticated) { if (!IsAuthCommand(request.KnownCommand)) return TypedRedisValue.Error("NOAUTH Authentication required."); @@ -190,6 +191,8 @@ protected virtual TypedRedisValue Auth(RedisClient client, in RedisRequest reque return TypedRedisValue.Error("ERR invalid password"); } + protected virtual RedisProtocol MaxProtocol => RedisProtocol.Resp3; + [RedisCommand(-1)] protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest request) { @@ -204,12 +207,14 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ case 2: protocol = RedisProtocol.Resp2; break; - case 3: // this client does not currently support RESP3 + case 3: protocol = RedisProtocol.Resp3; break; default: return TypedRedisValue.Error("NOPROTO unsupported protocol version"); } + protocol = (RedisProtocol)Math.Min((int)protocol, (int)MaxProtocol); + static TypedRedisValue ArgFail(in RespReader reader) => TypedRedisValue.Error($"ERR Syntax error in HELLO option '{reader.ReadString()}'\""); for (int i = 2; i < request.Count; i++) @@ -246,6 +251,12 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ } // all good, update client + long proto32 = protocol switch + { + >= RedisProtocol.Resp3 => 3, + >= RedisProtocol.Resp2 => 2, + _ => throw new InvalidOperationException($"Unexpected protocol: {protocol}"), + }; client.Protocol = protocol; client.IsAuthenticated = isAuthed; client.Name = name; @@ -256,7 +267,7 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ span[2] = TypedRedisValue.BulkString("version"); span[3] = TypedRedisValue.BulkString(VersionString); span[4] = TypedRedisValue.BulkString("proto"); - span[5] = TypedRedisValue.Integer(client.ProtocolVersion); + span[5] = TypedRedisValue.Integer(proto32); span[6] = TypedRedisValue.BulkString("id"); span[7] = TypedRedisValue.Integer(client.Id); span[8] = TypedRedisValue.BulkString("mode");