Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions src/main/java/com/trilead/ssh2/channel/ChannelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1676,24 +1676,12 @@ public void msgChannelOpenFailure(byte[] msg, int msglen) throws IOException

public void msgGlobalRequest(byte[] msg, int msglen) throws IOException
{
/* Currently we do not support any kind of global request */

TypesReader tr = new TypesReader(msg, 0, msglen);

tr.readByte(); // skip packet type
String requestName = tr.readString();
boolean wantReply = tr.readBoolean();

if (wantReply)
{
byte[] reply_failure = new byte[1];
reply_failure[0] = Packets.SSH_MSG_REQUEST_FAILURE;

tm.sendAsynchronousMessage(reply_failure);
}

/* We do not clean up the requestName String - that is OK for debug */

if (log.isEnabled())
log.log(80, "Got SSH_MSG_GLOBAL_REQUEST (" + requestName + ")");

Expand All @@ -1702,20 +1690,33 @@ public void msgGlobalRequest(byte[] msg, int msglen) throws IOException
try {
PacketGlobalHostkeys hostkeys = new PacketGlobalHostkeys(msg, 0, msglen);
processHostkeysAdvertisement(hostkeys, requestName);
} catch (IOException e) {
} catch (Exception e) {
if (log.isEnabled())
log.log(20, "Failed to parse hostkeys advertisement: " + e.getMessage());
log.log(20, "Failed to process hostkeys advertisement: " + e.getMessage());
}
// hostkeys-00@openssh.com typically has wantReply=false, but if the
// server does request a reply, acknowledge it since we processed it.
if (wantReply)
{
byte[] reply_success = new byte[1];
reply_success[0] = Packets.SSH_MSG_REQUEST_SUCCESS;
tm.sendAsynchronousMessage(reply_success);
}
return;
}
}

public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException {
synchronized (channels)
if (wantReply)
{
globalSuccessCounter++;
channels.notifyAll();
byte[] reply_failure = new byte[1];
reply_failure[0] = Packets.SSH_MSG_REQUEST_FAILURE;

tm.sendAsynchronousMessage(reply_failure);
}
}

public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException {
// Check for pending hostkeys-prove BEFORE incrementing the global counter,
// so the hostkeys-prove response doesn't interfere with other global request tracking.
synchronized (hostkeysProveLock) {
if (pendingHostkeysProve != null && !pendingHostkeysProve.completed) {
try {
Expand All @@ -1739,6 +1740,12 @@ public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException {
}
}

synchronized (channels)
{
globalSuccessCounter++;
channels.notifyAll();
}

if (log.isEnabled())
log.log(80, "Got SSH_MSG_REQUEST_SUCCESS");
}
Expand Down Expand Up @@ -1882,19 +1889,28 @@ private void processHostkeysAdvertisement(PacketGlobalHostkeys hostkeys, String
}

List<String> knownAlgos = extVerifier.getKnownKeyAlgorithmsForHost(hostname, port);
Set<String> knownAlgoSet = (knownAlgos != null) ? new HashSet<>(knownAlgos) : new HashSet<>();

// Normalize algorithm names so RSA signature variants (rsa-sha2-256,
// rsa-sha2-512) are treated as the same key type as ssh-rsa.
Set<String> normalizedKnownAlgoSet = new HashSet<>();
if (knownAlgos != null) {
for (String algo : knownAlgos) {
normalizedKnownAlgoSet.add(normalizeKeyAlgorithm(algo));
}
}

List<byte[]> newKeys = new ArrayList<>();
Set<String> advertisedAlgoSet = new HashSet<>();
Set<String> normalizedAdvertisedAlgoSet = new HashSet<>();

for (byte[] keyBlob : advertisedKeys) {
String keyAlgo = extractKeyAlgorithm(keyBlob);
if (keyAlgo == null)
continue;

advertisedAlgoSet.add(keyAlgo);
String normalizedAlgo = normalizeKeyAlgorithm(keyAlgo);
normalizedAdvertisedAlgoSet.add(normalizedAlgo);

if (!knownAlgoSet.contains(keyAlgo)) {
if (!normalizedKnownAlgoSet.contains(normalizedAlgo)) {
newKeys.add(keyBlob);
}
}
Expand All @@ -1903,7 +1919,7 @@ private void processHostkeysAdvertisement(PacketGlobalHostkeys hostkeys, String
for (String knownAlgo : knownAlgos) {
if (knownAlgo == null)
continue;
if (!advertisedAlgoSet.contains(knownAlgo)) {
if (!normalizedAdvertisedAlgoSet.contains(normalizeKeyAlgorithm(knownAlgo))) {
extVerifier.removeServerHostKey(hostname, port, knownAlgo, null);
if (log.isEnabled())
log.log(50, "Removed hostkey algorithm no longer advertised: " + knownAlgo);
Expand Down Expand Up @@ -2050,6 +2066,21 @@ private String extractKeyAlgorithm(byte[] keyBlob) throws IOException
return tr.readString();
}

/**
* Normalizes RSA algorithm variants to a canonical form for comparison.
* In SSH, rsa-sha2-256 and rsa-sha2-512 use the same RSA key as ssh-rsa
* (the difference is only the signature hash algorithm). Key blobs always
* identify as ssh-rsa regardless of which signature algorithm was negotiated.
*/
static String normalizeKeyAlgorithm(String algorithm)
{
if (RSASHA256Verify.ID_RSA_SHA_2_256.equals(algorithm) ||
RSASHA512Verify.ID_RSA_SHA_2_512.equals(algorithm)) {
return RSASHA1Verify.ID_SSH_RSA;
}
return algorithm;
}

private SSHSignature getSignatureVerifier(String algorithm)
{
if (RSASHA1Verify.ID_SSH_RSA.equals(algorithm))
Expand Down
183 changes: 183 additions & 0 deletions src/test/java/com/trilead/ssh2/channel/ChannelManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
import com.trilead.ssh2.packets.PacketGlobalHostkeys;
import com.trilead.ssh2.packets.Packets;
import com.trilead.ssh2.packets.TypesWriter;
import com.trilead.ssh2.signature.RSASHA1Verify;
import com.trilead.ssh2.signature.RSASHA256Verify;
import com.trilead.ssh2.signature.RSASHA512Verify;
import com.trilead.ssh2.transport.ITransportConnection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand All @@ -22,9 +27,13 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.nullable;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -778,4 +787,178 @@ public void testMsgChannelOpenWithUnknownType() throws IOException {
channelManager.msgChannelOpen(msg, offset);
verify(mockTransportConnection).sendAsynchronousMessage(any(byte[].class));
}

// ---- Host key rotation tests ----

/**
* Build an SSH key blob whose algorithm identifier is the given string.
* The blob is: uint32(len) + algorithm_name + uint32(len) + dummy_data.
* This is enough for extractKeyAlgorithm() which only reads the first string.
*/
private byte[] buildKeyBlob(String algorithm) {
TypesWriter tw = new TypesWriter();
tw.writeString(algorithm);
// Append a dummy "key data" field so it looks like a plausible key blob
tw.writeString(new byte[]{0x00, 0x01, 0x02, 0x03}, 0, 4);
return tw.getBytes();
}

/**
* Build an SSH_MSG_GLOBAL_REQUEST message for hostkeys-00@openssh.com
* containing the given host key blobs.
*/
private byte[] buildHostkeysGlobalRequest(String requestName, boolean wantReply, byte[]... keyBlobs) {
TypesWriter tw = new TypesWriter();
tw.writeByte(Packets.SSH_MSG_GLOBAL_REQUEST);
tw.writeString(requestName);
tw.writeBoolean(wantReply);
for (byte[] blob : keyBlobs) {
tw.writeString(blob, 0, blob.length);
}
return tw.getBytes();
}

/**
* Regression test for GitHub issue connectbot/connectbot#2023:
*
* Before the fix, when the stored algorithm was "rsa-sha2-512" and the
* key blob contained "ssh-rsa", processHostkeysAdvertisement would call
* removeServerHostKey with null hostKey (crashing Kotlin callers).
*
* After the fix, RSA algorithm variants are normalized so "rsa-sha2-512"
* and "ssh-rsa" are recognized as the same key type. removeServerHostKey
* should NOT be called, since the key is still present.
*/
@Test
public void testHostkeysAdvertisement_rsaAlgoMismatch_noRemovalAfterFix() throws Exception {
ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class);
when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt()))
.thenReturn(Collections.singletonList(RSASHA512Verify.ID_RSA_SHA_2_512));

when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier);
when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com");
when(mockTransportConnection.getPort()).thenReturn(22);

byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA);
byte[] msg = buildHostkeysGlobalRequest(
"hostkeys-00@openssh.com", false, rsaKeyBlob);

channelManager.handleMessage(msg, msg.length);

// After fix: rsa-sha2-512 is normalized to ssh-rsa, so the algorithm
// is recognized as still advertised. removeServerHostKey must NOT be called.
verify(mockVerifier, never()).removeServerHostKey(
anyString(), anyInt(), anyString(), any());
}

/**
* Regression test: even if removeServerHostKey throws (e.g., Kotlin's
* non-nullable parameter check), it must not propagate out of handleMessage
* and kill the SSH receiver thread.
*/
@Test
public void testHostkeysAdvertisement_removeThrows_doesNotCrashReceiverThread() throws Exception {
// Use a non-RSA algorithm so normalization doesn't prevent the removal call.
// Pretend the client knows "ssh-dss" but server no longer advertises it.
ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class);
when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt()))
.thenReturn(Collections.singletonList("ssh-dss"));
doThrow(new NullPointerException("Parameter specified as non-null is null: parameter hostKey"))
.when(mockVerifier).removeServerHostKey(anyString(), anyInt(), anyString(), nullable(byte[].class));

when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier);
when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com");
when(mockTransportConnection.getPort()).thenReturn(22);

// Server advertises only an ed25519 key (no DSS)
byte[] ed25519KeyBlob = buildKeyBlob("ssh-ed25519");
byte[] msg = buildHostkeysGlobalRequest(
"hostkeys-00@openssh.com", false, ed25519KeyBlob);

// Even if removeServerHostKey throws, handleMessage must NOT propagate it
channelManager.handleMessage(msg, msg.length);

// The call did happen (and threw), but the exception was caught
verify(mockVerifier).removeServerHostKey(
anyString(), anyInt(), anyString(), nullable(byte[].class));
}

/**
* Verify that when the stored algorithm matches the advertised key blob
* algorithm (both "ssh-rsa"), removeServerHostKey is NOT called.
* This is the baseline: no mismatch, no problem.
*/
@Test
public void testHostkeysAdvertisement_matchingAlgo_noRemoval() throws Exception {
ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class);
when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt()))
.thenReturn(Collections.singletonList(RSASHA1Verify.ID_SSH_RSA));

when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier);
when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com");
when(mockTransportConnection.getPort()).thenReturn(22);

byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA);
byte[] msg = buildHostkeysGlobalRequest(
"hostkeys-00@openssh.com", false, rsaKeyBlob);

channelManager.handleMessage(msg, msg.length);

// No algorithm mismatch, so removeServerHostKey should not be called
verify(mockVerifier, never()).removeServerHostKey(
anyString(), anyInt(), anyString(), any());
}

// ---- normalizeKeyAlgorithm tests ----

@Test
public void testNormalizeKeyAlgorithm_rsaSha2_512() {
assertEquals(RSASHA1Verify.ID_SSH_RSA,
ChannelManager.normalizeKeyAlgorithm(RSASHA512Verify.ID_RSA_SHA_2_512));
}

@Test
public void testNormalizeKeyAlgorithm_rsaSha2_256() {
assertEquals(RSASHA1Verify.ID_SSH_RSA,
ChannelManager.normalizeKeyAlgorithm(RSASHA256Verify.ID_RSA_SHA_2_256));
}

@Test
public void testNormalizeKeyAlgorithm_sshRsa_unchanged() {
assertEquals(RSASHA1Verify.ID_SSH_RSA,
ChannelManager.normalizeKeyAlgorithm(RSASHA1Verify.ID_SSH_RSA));
}

@Test
public void testNormalizeKeyAlgorithm_nonRsa_unchanged() {
assertEquals("ssh-ed25519",
ChannelManager.normalizeKeyAlgorithm("ssh-ed25519"));
assertEquals("ssh-dss",
ChannelManager.normalizeKeyAlgorithm("ssh-dss"));
assertEquals("ecdsa-sha2-nistp256",
ChannelManager.normalizeKeyAlgorithm("ecdsa-sha2-nistp256"));
}

// ---- msgGlobalRequest hostkeys reply tests ----

@Test
public void testMsgGlobalRequest_hostkeys_withReply_sendsSuccess() throws Exception {
ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class);
when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt())).thenReturn(null);
when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier);
when(mockTransportConnection.getHostname()).thenReturn("example.com");
when(mockTransportConnection.getPort()).thenReturn(22);

byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA);
byte[] msg = buildHostkeysGlobalRequest(
"hostkeys-00@openssh.com", true, rsaKeyBlob);

channelManager.handleMessage(msg, msg.length);

// Should send REQUEST_SUCCESS (not FAILURE) for handled hostkeys request
ArgumentCaptor<byte[]> replyCaptor = ArgumentCaptor.forClass(byte[].class);
verify(mockTransportConnection).sendAsynchronousMessage(replyCaptor.capture());
assertEquals(Packets.SSH_MSG_REQUEST_SUCCESS, replyCaptor.getValue()[0],
"Hostkeys global request should get SUCCESS reply, not FAILURE");
}
}
Loading