diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-26 23:42:05 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-26 23:42:05 +0200 |
| commit | e3dfc7a62fc1eb27a9fb68dd530064cdd2d5bb07 (patch) | |
| tree | 1be8b851ca34015ac6e16d2c60316e5cc8bb6ef6 | |
| parent | 637c4a09bfbe7045b0a639a616cfffc983da7e05 (diff) | |
Implement Raft append replication b85586a4-4eb9-4686-93c7-0ab14173baa5
| -rw-r--r-- | src/main/java/protocols/implementations/VSRaftProtocol.java | 81 | ||||
| -rw-r--r-- | src/test/java/protocols/implementations/VSRaftProtocolTest.java | 120 |
2 files changed, 193 insertions, 8 deletions
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java index 230c6d9..983c8d3 100644 --- a/src/main/java/protocols/implementations/VSRaftProtocol.java +++ b/src/main/java/protocols/implementations/VSRaftProtocol.java @@ -171,6 +171,7 @@ public class VSRaftProtocol extends VSAbstractProtocol { isCandidate = false; votesReceived = 0; voteResponsePids.clear(); + ackPids.clear(); leaderId = process.getProcessID(); lastHeartbeatTime = process.getTime(); isServer(true); @@ -186,6 +187,7 @@ public class VSRaftProtocol extends VSAbstractProtocol { currentContextIsServer(true); sendHeartbeat(); + sendAppendEntry(); currentContextIsServer(previousContextIsServer); } @@ -272,6 +274,29 @@ public class VSRaftProtocol extends VSAbstractProtocol { } /** + * Sends a simplified append-entry request for the configured log entry. + */ + private void sendAppendEntry() { + if (getVectorKeySet().contains("pids")) { + ackPids.addAll(getVector("pids")); + } + + if (ackPids.isEmpty()) { + return; + } + + logIndex++; + + VSMessage appendEntry = new VSMessage(); + appendEntry.setString("type", "appendEntry"); + appendEntry.setInteger("term", currentTerm); + appendEntry.setInteger("leaderId", leaderId); + appendEntry.setString("entry", getString("logEntry")); + appendEntry.setInteger("logIndex", logIndex); + sendMessage(appendEntry); + } + + /** * Dispatches Raft messages to the relevant handlers. * * @param recvMessage the received message @@ -283,6 +308,10 @@ public class VSRaftProtocol extends VSAbstractProtocol { handleVoteRequest(recvMessage); } else if ("voteResponse".equals(messageType)) { handleVoteResponse(recvMessage); + } else if ("appendEntry".equals(messageType)) { + handleAppendEntry(recvMessage); + } else if ("appendAck".equals(messageType)) { + handleAppendAck(recvMessage); } } @@ -345,6 +374,58 @@ public class VSRaftProtocol extends VSAbstractProtocol { } /** + * Handles an incoming append-entry request from the current leader. + * + * @param recvMessage the append-entry message + */ + private void handleAppendEntry(VSMessage recvMessage) { + int messageTerm = recvMessage.getInteger("term"); + int messageLeaderId = recvMessage.getInteger("leaderId"); + + if (messageTerm > currentTerm) { + becomeFollower(messageTerm, messageLeaderId); + } else if (messageTerm == currentTerm) { + leaderId = messageLeaderId; + isLeader = false; + isCandidate = false; + resetElectionTimeout(); + } else { + return; + } + + logIndex++; + + VSMessage appendAck = new VSMessage(); + appendAck.setString("type", "appendAck"); + appendAck.setInteger("term", currentTerm); + appendAck.setInteger("pid", process.getProcessID()); + appendAck.setInteger("logIndex", logIndex); + appendAck.setInteger("targetPid", messageLeaderId); + sendMessage(appendAck); + } + + /** + * Handles an append-entry acknowledgement on the leader. + * + * @param recvMessage the append acknowledgement + */ + private void handleAppendAck(VSMessage recvMessage) { + Integer responderPid = recvMessage.getIntegerObj("pid"); + + if (!isLeader || !isForMe(recvMessage) || responderPid == null || + !ackPids.contains(responderPid)) { + return; + } + + ackPids.remove(responderPid); + + if (ackPids.isEmpty()) { + commitIndex++; + log("Committed log index " + commitIndex); + } + } + + /** * Checks whether a directed response is meant for this process. * * @param recvMessage the received message diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java index 380aaae..f3dc0d6 100644 --- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java +++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java @@ -29,7 +29,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** - * Unit tests for VSRaftProtocol heartbeat behavior. + * Unit tests for VSRaftProtocol election and log replication behavior. */ class VSRaftProtocolTest { @@ -78,13 +78,19 @@ class VSRaftProtocolTest { protocol.onStart(); - verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockProcess, times(2)).sendMessage(messageCaptor.capture()); verify(mockTaskManager).addTask(taskCaptor.capture()); - VSMessage heartbeat = messageCaptor.getValue(); + VSMessage heartbeat = messageCaptor.getAllValues().get(0); assertEquals("heartbeat", heartbeat.getString("type")); assertEquals(0, heartbeat.getInteger("term")); assertEquals(7, heartbeat.getInteger("leaderId")); + VSMessage appendEntry = messageCaptor.getAllValues().get(1); + assertEquals("appendEntry", appendEntry.getString("type")); + assertEquals(0, appendEntry.getInteger("term")); + assertEquals(7, appendEntry.getInteger("leaderId")); + assertEquals("cmd1", appendEntry.getString("entry")); + assertEquals(1, appendEntry.getInteger("logIndex")); assertEquals(1600L, taskCaptor.getValue().getTaskTime()); } @@ -97,13 +103,13 @@ class VSRaftProtocolTest { protocol.onStart(); protocol.onServerScheduleStart(); - verify(mockProcess, times(2)).sendMessage(messageCaptor.capture()); + verify(mockProcess, times(3)).sendMessage(messageCaptor.capture()); verify(mockTaskManager, times(2)).addTask(taskCaptor.capture()); - assertEquals(2, messageCaptor.getAllValues().size()); + assertEquals(3, messageCaptor.getAllValues().size()); assertEquals(2, taskCaptor.getAllValues().size()); - VSMessage scheduledHeartbeat = messageCaptor.getAllValues().get(1); + VSMessage scheduledHeartbeat = messageCaptor.getAllValues().get(2); assertEquals("heartbeat", scheduledHeartbeat.getString("type")); assertEquals(0, scheduledHeartbeat.getInteger("term")); assertEquals(7, scheduledHeartbeat.getInteger("leaderId")); @@ -388,14 +394,20 @@ class VSRaftProtocolTest { protocol.onClientRecv(voteResponse); - verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockProcess, times(2)).sendMessage(messageCaptor.capture()); verify(mockTaskManager).removeAllTasks(any()); verify(mockTaskManager).addTask(taskCaptor.capture()); - VSMessage heartbeat = messageCaptor.getValue(); + VSMessage heartbeat = messageCaptor.getAllValues().get(0); assertEquals("heartbeat", heartbeat.getString("type")); assertEquals(3, heartbeat.getInteger("term")); assertEquals(7, heartbeat.getInteger("leaderId")); + VSMessage appendEntry = messageCaptor.getAllValues().get(1); + assertEquals("appendEntry", appendEntry.getString("type")); + assertEquals(3, appendEntry.getInteger("term")); + assertEquals(7, appendEntry.getInteger("leaderId")); + assertEquals("cmd1", appendEntry.getString("entry")); + assertEquals(1, appendEntry.getInteger("logIndex")); assertTrue(getBooleanField("isLeader")); assertFalse(getBooleanField("isCandidate")); assertEquals(7, getIntField("leaderId")); @@ -486,6 +498,92 @@ class VSRaftProtocolTest { assertFalse(getBooleanField("isLeader")); } + @Test + void testAppendEntryAcceptedByFollowerSendsAckAndAdvancesLogIndex() + throws Exception { + protocol.currentContextIsServer(false); + protocol.onClientInit(); + clearInvocations(mockProcess, mockTaskManager); + when(mockProcess.getTime()).thenReturn(600L, 600L); + + VSMessage appendEntry = new VSMessage(); + appendEntry.setString("type", "appendEntry"); + appendEntry.setInteger("term", 2); + appendEntry.setInteger("leaderId", 11); + appendEntry.setString("entry", "cmd2"); + appendEntry.setInteger("logIndex", 1); + + ArgumentCaptor<VSMessage> messageCaptor = + ArgumentCaptor.forClass(VSMessage.class); + ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class); + + protocol.onClientRecv(appendEntry); + + verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockTaskManager, times(2)).removeAllTasks(any()); + verify(mockTaskManager).addTask(taskCaptor.capture()); + + VSMessage appendAck = messageCaptor.getValue(); + assertEquals("appendAck", appendAck.getString("type")); + assertEquals(2, appendAck.getInteger("term")); + assertEquals(7, appendAck.getInteger("pid")); + assertEquals(1, appendAck.getInteger("logIndex")); + assertEquals(11, appendAck.getInteger("targetPid")); + assertEquals(2, getIntField("currentTerm")); + assertEquals(11, getIntField("leaderId")); + assertEquals(1, getIntField("logIndex")); + assertFalse(getBooleanField("isLeader")); + assertFalse(getBooleanField("isCandidate")); + assertEquals(5100L, taskCaptor.getValue().getTaskTime()); + } + + @Test + void testAppendAckForLeaderCommitsOnceAllFollowersAck() throws Exception { + setBooleanField("isLeader", true); + setIntField("logIndex", 1); + @SuppressWarnings("unchecked") + java.util.ArrayList<Integer> ackPids = + (java.util.ArrayList<Integer>) getObjectField("ackPids"); + ackPids.clear(); + ackPids.add(2); + + VSMessage appendAck = new VSMessage(); + appendAck.setString("type", "appendAck"); + appendAck.setInteger("term", 1); + appendAck.setInteger("pid", 2); + appendAck.setInteger("logIndex", 1); + appendAck.setInteger("targetPid", 7); + + protocol.onServerRecv(appendAck); + + verify(mockProcess).log("Committed log index 1"); + assertTrue(ackPids.isEmpty()); + assertEquals(1, getIntField("commitIndex")); + } + + @Test + void testAppendAckForDifferentLeaderTargetDoesNothing() throws Exception { + setBooleanField("isLeader", true); + @SuppressWarnings("unchecked") + java.util.ArrayList<Integer> ackPids = + (java.util.ArrayList<Integer>) getObjectField("ackPids"); + ackPids.clear(); + ackPids.add(2); + + VSMessage appendAck = new VSMessage(); + appendAck.setString("type", "appendAck"); + appendAck.setInteger("term", 1); + appendAck.setInteger("pid", 2); + appendAck.setInteger("logIndex", 1); + appendAck.setInteger("targetPid", 99); + + protocol.onServerRecv(appendAck); + + verify(mockProcess, never()).log(anyString()); + assertEquals(1, ackPids.size()); + assertEquals(0, getIntField("commitIndex")); + } + private void invokeBecomeFollower(int term, int leaderId) throws Exception { Method method = VSRaftProtocol.class.getDeclaredMethod( "becomeFollower", int.class, int.class); @@ -512,6 +610,12 @@ class VSRaftProtocolTest { return field.getInt(protocol); } + private Object getObjectField(String fieldName) throws Exception { + Field field = VSRaftProtocol.class.getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(protocol); + } + private boolean getBooleanField(String fieldName) throws Exception { Field field = VSRaftProtocol.class.getDeclaredField(fieldName); field.setAccessible(true); |
