diff options
| -rw-r--r-- | internal/client/tcpclient.go | 11 | ||||
| -rw-r--r-- | internal/server/tcpserver.go | 17 | ||||
| -rw-r--r-- | internal/tcp/tcp.go | 44 | ||||
| -rw-r--r-- | internal/tcp/tcp_test.go | 63 |
4 files changed, 118 insertions, 17 deletions
diff --git a/internal/client/tcpclient.go b/internal/client/tcpclient.go index 1a76ad6..3181fd4 100644 --- a/internal/client/tcpclient.go +++ b/internal/client/tcpclient.go @@ -3,10 +3,10 @@ package client import ( "context" "fmt" - "io" "log" "net" + "codeberg.org/snonux/gorum/internal/tcp" "codeberg.org/snonux/gorum/internal/vote" ) @@ -29,18 +29,15 @@ func tcpClientRun(ctx context.Context, node string, ch <-chan vote.Vote) error { } log.Println("tcpclient: sending", message, "to node", node) - - bytes := []byte(fmt.Sprintf("%s\n", message)) - _, err = conn.Write(bytes) - if err != nil { + if err := tcp.WriteStr(conn, message); err != nil { return err } - response, err := io.ReadAll(conn) + response, err := tcp.ReadStr(conn) if err != nil { return err } - log.Println("tcpclient: received", string(response), "from node", node) + log.Println("tcpclient: received", response, "from node", node) } } diff --git a/internal/server/tcpserver.go b/internal/server/tcpserver.go index b6d2ede..21324c3 100644 --- a/internal/server/tcpserver.go +++ b/internal/server/tcpserver.go @@ -1,13 +1,13 @@ package server import ( - "bufio" "context" "fmt" "log" "net" "codeberg.org/snonux/gorum/internal/config" + "codeberg.org/snonux/gorum/internal/tcp" ) type handlerCb func(message string) string @@ -41,11 +41,7 @@ func tcpServerRun(ctx context.Context, conf config.Config, cb handlerCb) error { func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) { defer conn.Close() - - var ( - remoteAddr = conn.RemoteAddr().String() - reader = bufio.NewReader(conn) - ) + remoteAddr := conn.RemoteAddr().String() for { select { @@ -53,15 +49,16 @@ func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) { log.Println("server: context done, disconnecting client:", remoteAddr) return default: - message, err := reader.ReadString('\n') + message, err := tcp.ReadStr(conn) if err != nil { - log.Println("server: client disconnected:", remoteAddr, err) + log.Println("server: unable to read message", remoteAddr, err) return } - log.Println("server: received message", remoteAddr, message) + log.Println("server: received message", message, "from", remoteAddr) response := cb(message) - if _, err := conn.Write([]byte(response)); err != nil { + + if err := tcp.WriteStr(conn, response); err != nil { log.Println("error:", err) } } diff --git a/internal/tcp/tcp.go b/internal/tcp/tcp.go new file mode 100644 index 0000000..3f9bafc --- /dev/null +++ b/internal/tcp/tcp.go @@ -0,0 +1,44 @@ +package tcp + +import ( + "encoding/binary" + "io" +) + +type Writer interface { + Write(b []byte) (n int, err error) +} + +type Reader interface { + Read(b []byte) (n int, err error) +} + +func WriteStr(w Writer, message string) error { + messageBytes := []byte(message) + sizeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(sizeBytes, uint64(len(messageBytes))) + + if _, err := w.Write(sizeBytes); err != nil { + return err + } + if _, err := w.Write(messageBytes); err != nil { + return err + } + + return nil +} + +func ReadStr(r Reader) (string, error) { + sizeBytes := make([]byte, 8) + if _, err := io.ReadFull(r, sizeBytes); err != nil { + return "", err + } + messageSize := binary.BigEndian.Uint64(sizeBytes) + + messageBytes := make([]byte, messageSize) + if _, err := io.ReadFull(r, messageBytes); err != nil { + return "", err + } + + return string(messageBytes), nil +} diff --git a/internal/tcp/tcp_test.go b/internal/tcp/tcp_test.go new file mode 100644 index 0000000..9b4d61a --- /dev/null +++ b/internal/tcp/tcp_test.go @@ -0,0 +1,63 @@ +package tcp + +import ( + "testing" +) + +type readTest struct { + sizeWritten *bool + sizeRead *bool + sizeBytes []byte + messageBytes []byte +} + +func (rt readTest) Write(b []byte) (n int, err error) { + if !*rt.sizeWritten { + copy(rt.sizeBytes, b) + *rt.sizeWritten = true + } else { + copy(rt.messageBytes, b) + } + + return len(b), nil +} + +func (rt readTest) Read(b []byte) (n int, err error) { + if !*rt.sizeRead { + copy(b, rt.sizeBytes) + *rt.sizeRead = true + } else { + copy(b, rt.messageBytes) + } + return len(b), nil +} + +func TestReadWrite(t *testing.T) { + t.Parallel() + + message := "Hello world!" + + var sizeWritten bool + var sizeRead bool + + rt := readTest{ + sizeWritten: &sizeWritten, + sizeRead: &sizeRead, + sizeBytes: make([]byte, 8), + messageBytes: make([]byte, len([]byte(message))), + } + + if err := WriteStr(rt, message); err != nil { + t.Errorf(err.Error()) + } + + response, err := ReadStr(rt) + if err != nil { + t.Errorf(err.Error()) + } + + if response != message { + t.Errorf("Expected response '%s' to be equal to original message '%s'!", + response, message) + } +} |
