Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/StackExchange.Redis/ResultProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2894,10 +2894,18 @@ public override bool SetResult(PhysicalConnection connection, Message message, i

if (connection.Protocol is null)
{
// if we didn't get a valid response from HELLO, then we have to assume RESP2 at some point
// If we didn't get a valid response from HELLO, then we have to assume RESP2 at some point.
// We need the protocol assigned before OnFullyEstablished so that the
// protocol is reliably known *before* we do next-steps.
connection.SetProtocol(RedisProtocol.Resp2);
}

if (final & establishConnection)
{
// This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection.
connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}");
}

return final;
}

Expand Down Expand Up @@ -2939,11 +2947,6 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
}
if (happy)
{
if (establishConnection)
{
// This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection.
connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}");
}
SetResult(message, happy);
return true;
}
Expand Down
8 changes: 7 additions & 1 deletion src/StackExchange.Redis/ServerEndPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -695,14 +695,20 @@ internal void OnFullyEstablished(PhysicalConnection connection, string source)
// Clear the unselectable flag ASAP since we are open for business
ClearUnselectable(UnselectableFlags.DidNotRespond);

bool isResp3 = KnowOrAssumeResp3();
// is *this specific* connection using RESP3? (without reference to config preferences)
bool isResp3 = connection?.Protocol is >= RedisProtocol.Resp3;
if (bridge == subscription || isResp3)
{
// Note: this MUST be fire and forget, because we might be in the middle of a Sync processing
// TracerProcessor which is executing this line inside a SetResultCore().
// Since we're issuing commands inside a SetResult path in a message, we'd create a deadlock by waiting.
Multiplexer.EnsureSubscriptions(CommandFlags.FireAndForget);
}
else if (SupportsSubscriptions && Multiplexer.RawConfig.Protocol > RedisProtocol.Resp2)
{
// interactive, and we wanted RESP3+, but we didn't get it; spin up pub/sub
Activate(ConnectionType.Subscription, null);
}
if (IsConnected && (IsSubscriberConnected || !SupportsSubscriptions || isResp3))
{
// Only connect on the second leg - we can accomplish this by checking both
Expand Down
3 changes: 2 additions & 1 deletion tests/StackExchange.Redis.Tests/InProcessTestServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace StackExchange.Redis.Tests;
public class InProcessTestServer : MemoryCacheRedisServer
{
private readonly ITestOutputHelper? _log;
public InProcessTestServer(ITestOutputHelper? log = null)
public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null)
: base(endpoint)
{
RedisVersion = RedisFeatures.v6_0_0; // for client to expect RESP3
_log = log;
Expand Down
118 changes: 118 additions & 0 deletions tests/StackExchange.Redis.Tests/Resp3HandshakeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using StackExchange.Redis.Server;
using Xunit;

namespace StackExchange.Redis.Tests;

public class Resp3HandshakeTests(ITestOutputHelper log)
{
public enum ServerResponse
{
Resp3, // up-level server style
Resp2, // DMC hybrid style, i.e. we know about it, but: "no, you'll take RESP2"
UnknownCommand, // down-level server style
}

[Flags]
public enum HandshakeFlags
{
None = 0,
Authenticated = 1 << 0,
TieBreaker = 1 << 1,
ConfigChannel = 1 << 2,
UsePubSub = 1 << 3,
UseDatabase = 1 << 4,
}

private static readonly int HandshakeFlagsCount = Enum.GetValues(typeof(HandshakeFlags)).Length - 1;
public static IEnumerable<object[]> GetHandshakeParameters()
{
// all client protocols, all server-response modes; all flag permutations
var clients = (RedisProtocol[])Enum.GetValues(typeof(RedisProtocol));
var servers = (ServerResponse[])Enum.GetValues(typeof(ServerResponse));
foreach (var client in clients)
{
foreach (var server in servers)
{
if (client is RedisProtocol.Resp2 & server is not ServerResponse.Resp2)
{
// we don't issue HELLO for this, nothing to test
}
else
{
int count = 1 << HandshakeFlagsCount;
for (int i = 0; i < count; i++)
{
yield return [client, server, (HandshakeFlags)i];
}
}
}
}
}

[Theory]
[MemberData(nameof(GetHandshakeParameters))]
public async Task Handshake(RedisProtocol client, ServerResponse server, HandshakeFlags flags)
{
using var serverObj = new HandshakeServer(server, log);
serverObj.Password = (flags & HandshakeFlags.Authenticated) == 0 ? null : "mypassword";
var config = serverObj.GetClientConfig();
config.Protocol = client;
config.TieBreaker = (flags & HandshakeFlags.TieBreaker) == 0 ? "" : "tiebreaker_key";
config.ConfigurationChannel = (flags & HandshakeFlags.ConfigChannel) == 0 ? "" : "broadcast_channel";

using var clientObj = await ConnectionMultiplexer.ConnectAsync(config);

var sub = clientObj.GetSubscriber();
var db = clientObj.GetDatabase();
ConcurrentBag<string> received = [];
RedisChannel channel = RedisChannel.Literal("mychannel");
RedisKey key = "mykey";
bool useDatabase = (flags & HandshakeFlags.UseDatabase) != 0;
bool usePubSub = (flags & HandshakeFlags.UsePubSub) != 0;

if (usePubSub)
{
await sub.SubscribeAsync(channel, (x, y) => received.Add(y!));
}
if (useDatabase)
{
await db.StringSetAsync(key, "myvalue");
}
if (usePubSub)
{
await sub.PublishAsync(channel, "msg payload");
for (int i = 0; i < 5 && received.IsEmpty; i++)
{
await Task.Delay(10, TestContext.Current.CancellationToken);
await sub.PingAsync();
}
Assert.Equal("msg payload", Assert.Single(received));
}

if (useDatabase)
{
Assert.Equal("myvalue", await db.StringGetAsync(key));
}
}

private static readonly EndPoint EP = new DnsEndPoint("home", 8000);
private sealed class HandshakeServer(ServerResponse response, ITestOutputHelper log)
: InProcessTestServer(log, EP)
{
protected override RedisProtocol MaxProtocol => response switch
{
ServerResponse.Resp3 => RedisProtocol.Resp3,
_ => RedisProtocol.Resp2,
};

protected override TypedRedisValue Hello(RedisClient client, in RedisRequest request)
=> response is ServerResponse.UnknownCommand
? request.CommandNotFound()
: base.Hello(client, in request);
}
}
1 change: 0 additions & 1 deletion toys/StackExchange.Redis.Server/RedisClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ internal bool ShouldSkipResponse()
public int Id { get; internal set; }
public bool IsAuthenticated { get; internal set; }
public RedisProtocol Protocol { get; internal set; } = RedisProtocol.Resp2;
public long ProtocolVersion => Protocol is RedisProtocol.Resp2 ? 2 : 3;

public void Dispose()
{
Expand Down
17 changes: 14 additions & 3 deletions toys/StackExchange.Redis.Server/RedisServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Linq;
using System.Net;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using RESPite;
Expand Down Expand Up @@ -158,7 +159,7 @@ protected override void AppendStats(StringBuilder sb)
public override TypedRedisValue Execute(RedisClient client, in RedisRequest request)
{
var pw = Password;
if (pw.Length != 0 & !client.IsAuthenticated)
if (!string.IsNullOrEmpty(pw) & !client.IsAuthenticated)
{
if (!IsAuthCommand(request.KnownCommand))
return TypedRedisValue.Error("NOAUTH Authentication required.");
Expand Down Expand Up @@ -190,6 +191,8 @@ protected virtual TypedRedisValue Auth(RedisClient client, in RedisRequest reque
return TypedRedisValue.Error("ERR invalid password");
}

protected virtual RedisProtocol MaxProtocol => RedisProtocol.Resp3;

[RedisCommand(-1)]
protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest request)
{
Expand All @@ -204,12 +207,14 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
case 2:
protocol = RedisProtocol.Resp2;
break;
case 3: // this client does not currently support RESP3
case 3:
protocol = RedisProtocol.Resp3;
break;
default:
return TypedRedisValue.Error("NOPROTO unsupported protocol version");
}
protocol = (RedisProtocol)Math.Min((int)protocol, (int)MaxProtocol);

static TypedRedisValue ArgFail(in RespReader reader) => TypedRedisValue.Error($"ERR Syntax error in HELLO option '{reader.ReadString()}'\"");

for (int i = 2; i < request.Count; i++)
Expand Down Expand Up @@ -246,6 +251,12 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
}

// all good, update client
long proto32 = protocol switch
{
>= RedisProtocol.Resp3 => 3,
>= RedisProtocol.Resp2 => 2,
_ => throw new InvalidOperationException($"Unexpected protocol: {protocol}"),
};
client.Protocol = protocol;
client.IsAuthenticated = isAuthed;
client.Name = name;
Expand All @@ -256,7 +267,7 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
span[2] = TypedRedisValue.BulkString("version");
span[3] = TypedRedisValue.BulkString(VersionString);
span[4] = TypedRedisValue.BulkString("proto");
span[5] = TypedRedisValue.Integer(client.ProtocolVersion);
span[5] = TypedRedisValue.Integer(proto32);
span[6] = TypedRedisValue.BulkString("id");
span[7] = TypedRedisValue.Integer(client.Id);
span[8] = TypedRedisValue.BulkString("mode");
Expand Down
Loading