diff --git a/src/main/java/com/trilead/ssh2/channel/ChannelManager.java b/src/main/java/com/trilead/ssh2/channel/ChannelManager.java index e097b6fa..f0678c14 100644 --- a/src/main/java/com/trilead/ssh2/channel/ChannelManager.java +++ b/src/main/java/com/trilead/ssh2/channel/ChannelManager.java @@ -1676,24 +1676,12 @@ public void msgChannelOpenFailure(byte[] msg, int msglen) throws IOException public void msgGlobalRequest(byte[] msg, int msglen) throws IOException { - /* Currently we do not support any kind of global request */ - TypesReader tr = new TypesReader(msg, 0, msglen); tr.readByte(); // skip packet type String requestName = tr.readString(); boolean wantReply = tr.readBoolean(); - if (wantReply) - { - byte[] reply_failure = new byte[1]; - reply_failure[0] = Packets.SSH_MSG_REQUEST_FAILURE; - - tm.sendAsynchronousMessage(reply_failure); - } - - /* We do not clean up the requestName String - that is OK for debug */ - if (log.isEnabled()) log.log(80, "Got SSH_MSG_GLOBAL_REQUEST (" + requestName + ")"); @@ -1702,20 +1690,33 @@ public void msgGlobalRequest(byte[] msg, int msglen) throws IOException try { PacketGlobalHostkeys hostkeys = new PacketGlobalHostkeys(msg, 0, msglen); processHostkeysAdvertisement(hostkeys, requestName); - } catch (IOException e) { + } catch (Exception e) { if (log.isEnabled()) - log.log(20, "Failed to parse hostkeys advertisement: " + e.getMessage()); + log.log(20, "Failed to process hostkeys advertisement: " + e.getMessage()); } + // hostkeys-00@openssh.com typically has wantReply=false, but if the + // server does request a reply, acknowledge it since we processed it. + if (wantReply) + { + byte[] reply_success = new byte[1]; + reply_success[0] = Packets.SSH_MSG_REQUEST_SUCCESS; + tm.sendAsynchronousMessage(reply_success); + } + return; } - } - public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException { - synchronized (channels) + if (wantReply) { - globalSuccessCounter++; - channels.notifyAll(); + byte[] reply_failure = new byte[1]; + reply_failure[0] = Packets.SSH_MSG_REQUEST_FAILURE; + + tm.sendAsynchronousMessage(reply_failure); } + } + public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException { + // Check for pending hostkeys-prove BEFORE incrementing the global counter, + // so the hostkeys-prove response doesn't interfere with other global request tracking. synchronized (hostkeysProveLock) { if (pendingHostkeysProve != null && !pendingHostkeysProve.completed) { try { @@ -1739,6 +1740,12 @@ public void msgGlobalSuccess(byte[] msg, int msglen) throws IOException { } } + synchronized (channels) + { + globalSuccessCounter++; + channels.notifyAll(); + } + if (log.isEnabled()) log.log(80, "Got SSH_MSG_REQUEST_SUCCESS"); } @@ -1882,19 +1889,28 @@ private void processHostkeysAdvertisement(PacketGlobalHostkeys hostkeys, String } List knownAlgos = extVerifier.getKnownKeyAlgorithmsForHost(hostname, port); - Set knownAlgoSet = (knownAlgos != null) ? new HashSet<>(knownAlgos) : new HashSet<>(); + + // Normalize algorithm names so RSA signature variants (rsa-sha2-256, + // rsa-sha2-512) are treated as the same key type as ssh-rsa. + Set normalizedKnownAlgoSet = new HashSet<>(); + if (knownAlgos != null) { + for (String algo : knownAlgos) { + normalizedKnownAlgoSet.add(normalizeKeyAlgorithm(algo)); + } + } List newKeys = new ArrayList<>(); - Set advertisedAlgoSet = new HashSet<>(); + Set normalizedAdvertisedAlgoSet = new HashSet<>(); for (byte[] keyBlob : advertisedKeys) { String keyAlgo = extractKeyAlgorithm(keyBlob); if (keyAlgo == null) continue; - advertisedAlgoSet.add(keyAlgo); + String normalizedAlgo = normalizeKeyAlgorithm(keyAlgo); + normalizedAdvertisedAlgoSet.add(normalizedAlgo); - if (!knownAlgoSet.contains(keyAlgo)) { + if (!normalizedKnownAlgoSet.contains(normalizedAlgo)) { newKeys.add(keyBlob); } } @@ -1903,7 +1919,7 @@ private void processHostkeysAdvertisement(PacketGlobalHostkeys hostkeys, String for (String knownAlgo : knownAlgos) { if (knownAlgo == null) continue; - if (!advertisedAlgoSet.contains(knownAlgo)) { + if (!normalizedAdvertisedAlgoSet.contains(normalizeKeyAlgorithm(knownAlgo))) { extVerifier.removeServerHostKey(hostname, port, knownAlgo, null); if (log.isEnabled()) log.log(50, "Removed hostkey algorithm no longer advertised: " + knownAlgo); @@ -2050,6 +2066,21 @@ private String extractKeyAlgorithm(byte[] keyBlob) throws IOException return tr.readString(); } + /** + * Normalizes RSA algorithm variants to a canonical form for comparison. + * In SSH, rsa-sha2-256 and rsa-sha2-512 use the same RSA key as ssh-rsa + * (the difference is only the signature hash algorithm). Key blobs always + * identify as ssh-rsa regardless of which signature algorithm was negotiated. + */ + static String normalizeKeyAlgorithm(String algorithm) + { + if (RSASHA256Verify.ID_RSA_SHA_2_256.equals(algorithm) || + RSASHA512Verify.ID_RSA_SHA_2_512.equals(algorithm)) { + return RSASHA1Verify.ID_SSH_RSA; + } + return algorithm; + } + private SSHSignature getSignatureVerifier(String algorithm) { if (RSASHA1Verify.ID_SSH_RSA.equals(algorithm)) diff --git a/src/test/java/com/trilead/ssh2/channel/ChannelManagerTest.java b/src/test/java/com/trilead/ssh2/channel/ChannelManagerTest.java index 3452d3e9..2262d2a8 100644 --- a/src/test/java/com/trilead/ssh2/channel/ChannelManagerTest.java +++ b/src/test/java/com/trilead/ssh2/channel/ChannelManagerTest.java @@ -5,15 +5,20 @@ import com.trilead.ssh2.packets.PacketGlobalHostkeys; import com.trilead.ssh2.packets.Packets; import com.trilead.ssh2.packets.TypesWriter; +import com.trilead.ssh2.signature.RSASHA1Verify; +import com.trilead.ssh2.signature.RSASHA256Verify; +import com.trilead.ssh2.signature.RSASHA512Verify; import com.trilead.ssh2.transport.ITransportConnection; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -22,9 +27,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.nullable; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -778,4 +787,178 @@ public void testMsgChannelOpenWithUnknownType() throws IOException { channelManager.msgChannelOpen(msg, offset); verify(mockTransportConnection).sendAsynchronousMessage(any(byte[].class)); } + + // ---- Host key rotation tests ---- + + /** + * Build an SSH key blob whose algorithm identifier is the given string. + * The blob is: uint32(len) + algorithm_name + uint32(len) + dummy_data. + * This is enough for extractKeyAlgorithm() which only reads the first string. + */ + private byte[] buildKeyBlob(String algorithm) { + TypesWriter tw = new TypesWriter(); + tw.writeString(algorithm); + // Append a dummy "key data" field so it looks like a plausible key blob + tw.writeString(new byte[]{0x00, 0x01, 0x02, 0x03}, 0, 4); + return tw.getBytes(); + } + + /** + * Build an SSH_MSG_GLOBAL_REQUEST message for hostkeys-00@openssh.com + * containing the given host key blobs. + */ + private byte[] buildHostkeysGlobalRequest(String requestName, boolean wantReply, byte[]... keyBlobs) { + TypesWriter tw = new TypesWriter(); + tw.writeByte(Packets.SSH_MSG_GLOBAL_REQUEST); + tw.writeString(requestName); + tw.writeBoolean(wantReply); + for (byte[] blob : keyBlobs) { + tw.writeString(blob, 0, blob.length); + } + return tw.getBytes(); + } + + /** + * Regression test for GitHub issue connectbot/connectbot#2023: + * + * Before the fix, when the stored algorithm was "rsa-sha2-512" and the + * key blob contained "ssh-rsa", processHostkeysAdvertisement would call + * removeServerHostKey with null hostKey (crashing Kotlin callers). + * + * After the fix, RSA algorithm variants are normalized so "rsa-sha2-512" + * and "ssh-rsa" are recognized as the same key type. removeServerHostKey + * should NOT be called, since the key is still present. + */ + @Test + public void testHostkeysAdvertisement_rsaAlgoMismatch_noRemovalAfterFix() throws Exception { + ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class); + when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt())) + .thenReturn(Collections.singletonList(RSASHA512Verify.ID_RSA_SHA_2_512)); + + when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier); + when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com"); + when(mockTransportConnection.getPort()).thenReturn(22); + + byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA); + byte[] msg = buildHostkeysGlobalRequest( + "hostkeys-00@openssh.com", false, rsaKeyBlob); + + channelManager.handleMessage(msg, msg.length); + + // After fix: rsa-sha2-512 is normalized to ssh-rsa, so the algorithm + // is recognized as still advertised. removeServerHostKey must NOT be called. + verify(mockVerifier, never()).removeServerHostKey( + anyString(), anyInt(), anyString(), any()); + } + + /** + * Regression test: even if removeServerHostKey throws (e.g., Kotlin's + * non-nullable parameter check), it must not propagate out of handleMessage + * and kill the SSH receiver thread. + */ + @Test + public void testHostkeysAdvertisement_removeThrows_doesNotCrashReceiverThread() throws Exception { + // Use a non-RSA algorithm so normalization doesn't prevent the removal call. + // Pretend the client knows "ssh-dss" but server no longer advertises it. + ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class); + when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt())) + .thenReturn(Collections.singletonList("ssh-dss")); + doThrow(new NullPointerException("Parameter specified as non-null is null: parameter hostKey")) + .when(mockVerifier).removeServerHostKey(anyString(), anyInt(), anyString(), nullable(byte[].class)); + + when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier); + when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com"); + when(mockTransportConnection.getPort()).thenReturn(22); + + // Server advertises only an ed25519 key (no DSS) + byte[] ed25519KeyBlob = buildKeyBlob("ssh-ed25519"); + byte[] msg = buildHostkeysGlobalRequest( + "hostkeys-00@openssh.com", false, ed25519KeyBlob); + + // Even if removeServerHostKey throws, handleMessage must NOT propagate it + channelManager.handleMessage(msg, msg.length); + + // The call did happen (and threw), but the exception was caught + verify(mockVerifier).removeServerHostKey( + anyString(), anyInt(), anyString(), nullable(byte[].class)); + } + + /** + * Verify that when the stored algorithm matches the advertised key blob + * algorithm (both "ssh-rsa"), removeServerHostKey is NOT called. + * This is the baseline: no mismatch, no problem. + */ + @Test + public void testHostkeysAdvertisement_matchingAlgo_noRemoval() throws Exception { + ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class); + when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt())) + .thenReturn(Collections.singletonList(RSASHA1Verify.ID_SSH_RSA)); + + when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier); + when(mockTransportConnection.getHostname()).thenReturn("esxi.example.com"); + when(mockTransportConnection.getPort()).thenReturn(22); + + byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA); + byte[] msg = buildHostkeysGlobalRequest( + "hostkeys-00@openssh.com", false, rsaKeyBlob); + + channelManager.handleMessage(msg, msg.length); + + // No algorithm mismatch, so removeServerHostKey should not be called + verify(mockVerifier, never()).removeServerHostKey( + anyString(), anyInt(), anyString(), any()); + } + + // ---- normalizeKeyAlgorithm tests ---- + + @Test + public void testNormalizeKeyAlgorithm_rsaSha2_512() { + assertEquals(RSASHA1Verify.ID_SSH_RSA, + ChannelManager.normalizeKeyAlgorithm(RSASHA512Verify.ID_RSA_SHA_2_512)); + } + + @Test + public void testNormalizeKeyAlgorithm_rsaSha2_256() { + assertEquals(RSASHA1Verify.ID_SSH_RSA, + ChannelManager.normalizeKeyAlgorithm(RSASHA256Verify.ID_RSA_SHA_2_256)); + } + + @Test + public void testNormalizeKeyAlgorithm_sshRsa_unchanged() { + assertEquals(RSASHA1Verify.ID_SSH_RSA, + ChannelManager.normalizeKeyAlgorithm(RSASHA1Verify.ID_SSH_RSA)); + } + + @Test + public void testNormalizeKeyAlgorithm_nonRsa_unchanged() { + assertEquals("ssh-ed25519", + ChannelManager.normalizeKeyAlgorithm("ssh-ed25519")); + assertEquals("ssh-dss", + ChannelManager.normalizeKeyAlgorithm("ssh-dss")); + assertEquals("ecdsa-sha2-nistp256", + ChannelManager.normalizeKeyAlgorithm("ecdsa-sha2-nistp256")); + } + + // ---- msgGlobalRequest hostkeys reply tests ---- + + @Test + public void testMsgGlobalRequest_hostkeys_withReply_sendsSuccess() throws Exception { + ExtendedServerHostKeyVerifier mockVerifier = mock(ExtendedServerHostKeyVerifier.class); + when(mockVerifier.getKnownKeyAlgorithmsForHost(anyString(), anyInt())).thenReturn(null); + when(mockTransportConnection.getServerHostKeyVerifier()).thenReturn(mockVerifier); + when(mockTransportConnection.getHostname()).thenReturn("example.com"); + when(mockTransportConnection.getPort()).thenReturn(22); + + byte[] rsaKeyBlob = buildKeyBlob(RSASHA1Verify.ID_SSH_RSA); + byte[] msg = buildHostkeysGlobalRequest( + "hostkeys-00@openssh.com", true, rsaKeyBlob); + + channelManager.handleMessage(msg, msg.length); + + // Should send REQUEST_SUCCESS (not FAILURE) for handled hostkeys request + ArgumentCaptor replyCaptor = ArgumentCaptor.forClass(byte[].class); + verify(mockTransportConnection).sendAsynchronousMessage(replyCaptor.capture()); + assertEquals(Packets.SSH_MSG_REQUEST_SUCCESS, replyCaptor.getValue()[0], + "Hostkeys global request should get SUCCESS reply, not FAILURE"); + } }