summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/java/protocols/VSAbstractProtocol.java9
-rw-r--r--src/main/java/protocols/implementations/VSRaftProtocol.java16
-rw-r--r--src/test/java/protocols/implementations/VSRaftProtocolTest.java137
3 files changed, 160 insertions, 2 deletions
diff --git a/src/main/java/protocols/VSAbstractProtocol.java b/src/main/java/protocols/VSAbstractProtocol.java
index 505a6aa..1695c25 100644
--- a/src/main/java/protocols/VSAbstractProtocol.java
+++ b/src/main/java/protocols/VSAbstractProtocol.java
@@ -247,6 +247,15 @@ abstract public class VSAbstractProtocol extends VSAbstractEvent {
}
/**
+ * Checks whether the protocol currently runs in server context.
+ *
+ * @return true if the current context is server, otherwise false
+ */
+ public final boolean currentContextIsServer() {
+ return currentContextIsServer;
+ }
+
+ /**
* Checks how the protocol will start
*
* @return true, if this protocol uses onServerStart instead of
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java
index 4a310ee..eaf63a7 100644
--- a/src/main/java/protocols/implementations/VSRaftProtocol.java
+++ b/src/main/java/protocols/implementations/VSRaftProtocol.java
@@ -166,6 +166,7 @@ public class VSRaftProtocol extends VSAbstractProtocol {
* @param newLeaderId the known leader in that term, or -1 if unknown
*/
private void becomeFollower(int term, int newLeaderId) {
+ clearServerSchedules();
isLeader = false;
isCandidate = false;
currentTerm = term;
@@ -181,9 +182,24 @@ public class VSRaftProtocol extends VSAbstractProtocol {
private void resetElectionTimeout() {
long jitterPercentage = Math.abs(process.getRandomPercentage());
long jitter = (getLong("electionJitter") * jitterPercentage) / 100L;
+ boolean previousContextIsServer = currentContextIsServer();
+ currentContextIsServer(false);
removeSchedules();
scheduleAt(process.getTime() + getLong("electionTimeout") + jitter);
+ currentContextIsServer(previousContextIsServer);
+ }
+
+ /**
+ * Clears any active server-side schedules while preserving the caller
+ * context.
+ */
+ private void clearServerSchedules() {
+ boolean previousContextIsServer = currentContextIsServer();
+
+ currentContextIsServer(true);
+ removeSchedules();
+ currentContextIsServer(previousContextIsServer);
}
/**
diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
index 6410a4e..7f4f867 100644
--- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java
+++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
@@ -5,6 +5,7 @@ import core.VSMessage;
import core.VSTask;
import core.VSTaskManager;
import core.time.VSVectorTime;
+import events.internal.VSProtocolScheduleEvent;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
@@ -13,8 +14,15 @@ import org.mockito.MockitoAnnotations;
import prefs.VSPrefs;
import simulator.VSSimulatorVisualization;
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -58,6 +66,7 @@ class VSRaftProtocolTest {
when(mockPrefs.getString(anyString())).thenReturn("TestString");
when(mockProcess.getTime()).thenReturn(100L);
when(mockProcess.getProcessID()).thenReturn(7);
+ when(mockProcess.getRandomPercentage()).thenReturn(25);
}
@Test
@@ -106,7 +115,131 @@ class VSRaftProtocolTest {
protocol.onServerSchedule();
- verify(mockProcess, never()).sendMessage(org.mockito.ArgumentMatchers.any());
- verify(mockTaskManager, never()).addTask(org.mockito.ArgumentMatchers.any());
+ verify(mockProcess, never()).sendMessage(any());
+ verify(mockTaskManager, never()).addTask(any());
+ }
+
+ @Test
+ void testOnClientInitSchedulesRandomizedElectionTimeout() {
+ protocol.currentContextIsServer(false);
+
+ ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class);
+
+ protocol.onClientInit();
+
+ verify(mockTaskManager).removeAllTasks(any());
+ verify(mockTaskManager).addTask(taskCaptor.capture());
+
+ VSTask task = taskCaptor.getValue();
+ assertEquals(4600L, task.getTaskTime());
+ assertFalse(((VSProtocolScheduleEvent) task.getEvent()).isServerSchedule());
+ }
+
+ @Test
+ void testOnClientScheduleStartsElectionAfterTimeout() throws Exception {
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getTime()).thenReturn(4200L);
+
+ ArgumentCaptor<VSMessage> messageCaptor =
+ ArgumentCaptor.forClass(VSMessage.class);
+ ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class);
+
+ protocol.onClientSchedule();
+
+ verify(mockProcess).sendMessage(messageCaptor.capture());
+ verify(mockTaskManager).removeAllTasks(any());
+ verify(mockTaskManager).addTask(taskCaptor.capture());
+
+ VSMessage voteRequest = messageCaptor.getValue();
+ assertEquals("voteRequest", voteRequest.getString("type"));
+ assertEquals(1, voteRequest.getInteger("term"));
+ assertEquals(7, voteRequest.getInteger("candidateId"));
+ assertTrue(getBooleanField("isCandidate"));
+ assertFalse(getBooleanField("isLeader"));
+ assertEquals(1, getIntField("votesReceived"));
+ assertEquals(7, getIntField("votedFor"));
+ assertEquals(8700L, taskCaptor.getValue().getTaskTime());
+ assertFalse(
+ ((VSProtocolScheduleEvent) taskCaptor.getValue().getEvent())
+ .isServerSchedule());
+ }
+
+ @Test
+ void testCandidateTimeoutStartsNewElectionAndReschedules() throws Exception {
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ when(mockProcess.getTime()).thenReturn(4200L, 4200L, 4200L);
+
+ protocol.onClientSchedule();
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getTime()).thenReturn(8401L, 8401L, 8401L);
+
+ ArgumentCaptor<VSMessage> messageCaptor =
+ ArgumentCaptor.forClass(VSMessage.class);
+ ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class);
+
+ protocol.onClientSchedule();
+
+ verify(mockProcess).sendMessage(messageCaptor.capture());
+ verify(mockTaskManager).removeAllTasks(any());
+ verify(mockTaskManager).addTask(taskCaptor.capture());
+
+ VSMessage voteRequest = messageCaptor.getValue();
+ assertEquals("voteRequest", voteRequest.getString("type"));
+ assertEquals(2, voteRequest.getInteger("term"));
+ assertEquals(2, getIntField("currentTerm"));
+ assertEquals(1, getIntField("votesReceived"));
+ assertEquals(12901L, taskCaptor.getValue().getTaskTime());
+ }
+
+ @Test
+ void testBecomeFollowerFromServerContextCancelsHeartbeatsAndRearmsClientTimeout()
+ throws Exception {
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ clearInvocations(mockProcess, mockTaskManager);
+
+ protocol.onStart();
+ clearInvocations(mockProcess, mockTaskManager);
+ protocol.currentContextIsServer(true);
+ when(mockProcess.getTime()).thenReturn(300L);
+
+ ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class);
+
+ invokeBecomeFollower(4, 11);
+
+ verify(mockTaskManager, times(2)).removeAllTasks(any());
+ verify(mockTaskManager).addTask(taskCaptor.capture());
+ assertTrue(protocol.currentContextIsServer());
+ assertFalse(getBooleanField("isLeader"));
+ assertFalse(getBooleanField("isCandidate"));
+ assertEquals(4, getIntField("currentTerm"));
+ assertEquals(11, getIntField("leaderId"));
+ assertEquals(-1, getIntField("votedFor"));
+ assertEquals(4800L, taskCaptor.getValue().getTaskTime());
+ assertFalse(
+ ((VSProtocolScheduleEvent) taskCaptor.getValue().getEvent())
+ .isServerSchedule());
+ }
+
+ private void invokeBecomeFollower(int term, int leaderId) throws Exception {
+ Method method = VSRaftProtocol.class.getDeclaredMethod(
+ "becomeFollower", int.class, int.class);
+ method.setAccessible(true);
+ method.invoke(protocol, term, leaderId);
+ }
+
+ private int getIntField(String fieldName) throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ return field.getInt(protocol);
+ }
+
+ private boolean getBooleanField(String fieldName) throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ return field.getBoolean(protocol);
}
}