summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2023-06-18 13:40:26 +0300
committerPaul Buetow <paul@buetow.org>2023-06-18 13:40:26 +0300
commit93b3e63e89594b2b7a05b791a5cfc366f10c763b (patch)
treeec2b53bb30393481ce5d6c969bca62e577199ce0
parente3c903c49c798531124a1c59859f47d04d4534b9 (diff)
add universal tcp message reader and writer
-rw-r--r--internal/client/tcpclient.go11
-rw-r--r--internal/server/tcpserver.go17
-rw-r--r--internal/tcp/tcp.go44
-rw-r--r--internal/tcp/tcp_test.go63
4 files changed, 118 insertions, 17 deletions
diff --git a/internal/client/tcpclient.go b/internal/client/tcpclient.go
index 1a76ad6..3181fd4 100644
--- a/internal/client/tcpclient.go
+++ b/internal/client/tcpclient.go
@@ -3,10 +3,10 @@ package client
import (
"context"
"fmt"
- "io"
"log"
"net"
+ "codeberg.org/snonux/gorum/internal/tcp"
"codeberg.org/snonux/gorum/internal/vote"
)
@@ -29,18 +29,15 @@ func tcpClientRun(ctx context.Context, node string, ch <-chan vote.Vote) error {
}
log.Println("tcpclient: sending", message, "to node", node)
-
- bytes := []byte(fmt.Sprintf("%s\n", message))
- _, err = conn.Write(bytes)
- if err != nil {
+ if err := tcp.WriteStr(conn, message); err != nil {
return err
}
- response, err := io.ReadAll(conn)
+ response, err := tcp.ReadStr(conn)
if err != nil {
return err
}
- log.Println("tcpclient: received", string(response), "from node", node)
+ log.Println("tcpclient: received", response, "from node", node)
}
}
diff --git a/internal/server/tcpserver.go b/internal/server/tcpserver.go
index b6d2ede..21324c3 100644
--- a/internal/server/tcpserver.go
+++ b/internal/server/tcpserver.go
@@ -1,13 +1,13 @@
package server
import (
- "bufio"
"context"
"fmt"
"log"
"net"
"codeberg.org/snonux/gorum/internal/config"
+ "codeberg.org/snonux/gorum/internal/tcp"
)
type handlerCb func(message string) string
@@ -41,11 +41,7 @@ func tcpServerRun(ctx context.Context, conf config.Config, cb handlerCb) error {
func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) {
defer conn.Close()
-
- var (
- remoteAddr = conn.RemoteAddr().String()
- reader = bufio.NewReader(conn)
- )
+ remoteAddr := conn.RemoteAddr().String()
for {
select {
@@ -53,15 +49,16 @@ func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) {
log.Println("server: context done, disconnecting client:", remoteAddr)
return
default:
- message, err := reader.ReadString('\n')
+ message, err := tcp.ReadStr(conn)
if err != nil {
- log.Println("server: client disconnected:", remoteAddr, err)
+ log.Println("server: unable to read message", remoteAddr, err)
return
}
- log.Println("server: received message", remoteAddr, message)
+ log.Println("server: received message", message, "from", remoteAddr)
response := cb(message)
- if _, err := conn.Write([]byte(response)); err != nil {
+
+ if err := tcp.WriteStr(conn, response); err != nil {
log.Println("error:", err)
}
}
diff --git a/internal/tcp/tcp.go b/internal/tcp/tcp.go
new file mode 100644
index 0000000..3f9bafc
--- /dev/null
+++ b/internal/tcp/tcp.go
@@ -0,0 +1,44 @@
+package tcp
+
+import (
+ "encoding/binary"
+ "io"
+)
+
+type Writer interface {
+ Write(b []byte) (n int, err error)
+}
+
+type Reader interface {
+ Read(b []byte) (n int, err error)
+}
+
+func WriteStr(w Writer, message string) error {
+ messageBytes := []byte(message)
+ sizeBytes := make([]byte, 8)
+ binary.BigEndian.PutUint64(sizeBytes, uint64(len(messageBytes)))
+
+ if _, err := w.Write(sizeBytes); err != nil {
+ return err
+ }
+ if _, err := w.Write(messageBytes); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func ReadStr(r Reader) (string, error) {
+ sizeBytes := make([]byte, 8)
+ if _, err := io.ReadFull(r, sizeBytes); err != nil {
+ return "", err
+ }
+ messageSize := binary.BigEndian.Uint64(sizeBytes)
+
+ messageBytes := make([]byte, messageSize)
+ if _, err := io.ReadFull(r, messageBytes); err != nil {
+ return "", err
+ }
+
+ return string(messageBytes), nil
+}
diff --git a/internal/tcp/tcp_test.go b/internal/tcp/tcp_test.go
new file mode 100644
index 0000000..9b4d61a
--- /dev/null
+++ b/internal/tcp/tcp_test.go
@@ -0,0 +1,63 @@
+package tcp
+
+import (
+ "testing"
+)
+
+type readTest struct {
+ sizeWritten *bool
+ sizeRead *bool
+ sizeBytes []byte
+ messageBytes []byte
+}
+
+func (rt readTest) Write(b []byte) (n int, err error) {
+ if !*rt.sizeWritten {
+ copy(rt.sizeBytes, b)
+ *rt.sizeWritten = true
+ } else {
+ copy(rt.messageBytes, b)
+ }
+
+ return len(b), nil
+}
+
+func (rt readTest) Read(b []byte) (n int, err error) {
+ if !*rt.sizeRead {
+ copy(b, rt.sizeBytes)
+ *rt.sizeRead = true
+ } else {
+ copy(b, rt.messageBytes)
+ }
+ return len(b), nil
+}
+
+func TestReadWrite(t *testing.T) {
+ t.Parallel()
+
+ message := "Hello world!"
+
+ var sizeWritten bool
+ var sizeRead bool
+
+ rt := readTest{
+ sizeWritten: &sizeWritten,
+ sizeRead: &sizeRead,
+ sizeBytes: make([]byte, 8),
+ messageBytes: make([]byte, len([]byte(message))),
+ }
+
+ if err := WriteStr(rt, message); err != nil {
+ t.Errorf(err.Error())
+ }
+
+ response, err := ReadStr(rt)
+ if err != nil {
+ t.Errorf(err.Error())
+ }
+
+ if response != message {
+ t.Errorf("Expected response '%s' to be equal to original message '%s'!",
+ response, message)
+ }
+}