summaryrefslogtreecommitdiff
path: root/internal/lsp/transport_concurrency_test.go
blob: f2390c71b558c9218b9b8c04f210e140b1860bf0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package lsp

import (
	"bytes"
	"encoding/json"
	"io"
	"log"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
)

// raceBuffer is an io.Writer for tests. It accumulates everything written to
// it and latches a flag if it ever sees two Write calls in progress at once.
type raceBuffer struct {
	// buf holds the accumulated output; it is only touched while bufMu is held.
	buf   bytes.Buffer
	bufMu sync.Mutex
	// inWrite is set to 1 on entry to Write and reset to 0 on exit.
	inWrite atomic.Int32
	// raced becomes true the first time Write finds inWrite already set.
	raced atomic.Bool
}

// Write appends p to the buffer. It first marks the writer busy and yields the
// processor, which widens the window in which an unsynchronized concurrent
// caller would also enter Write and be recorded as a race.
func (b *raceBuffer) Write(p []byte) (int, error) {
	if wasBusy := b.inWrite.Swap(1) != 0; wasBusy {
		b.raced.Store(true)
	}
	// Deferred calls run last-in first-out: the mutex is released before the
	// busy marker is cleared.
	defer b.inWrite.Store(0)

	runtime.Gosched()

	b.bufMu.Lock()
	defer b.bufMu.Unlock()
	return b.buf.Write(p)
}

// Bytes returns a copy of everything written so far, so the caller may keep or
// modify it freely. It returns nil when nothing has been written.
func (b *raceBuffer) Bytes() []byte {
	b.bufMu.Lock()
	snapshot := b.buf.Bytes()
	var out []byte
	out = append(out, snapshot...)
	b.bufMu.Unlock()
	return out
}

// Raced reports whether overlapping Write calls were ever observed.
func (b *raceBuffer) Raced() bool {
	sawOverlap := b.raced.Load()
	return sawOverlap
}

// TestServerReplySerializesWrites issues many srv.reply calls concurrently
// against a writer that detects overlapping Write calls. It then checks that
// no overlap was seen and that the output is a clean sequence of exactly one
// Content-Length framed JSON message per reply.
func TestServerReplySerializesWrites(t *testing.T) {
	t.Parallel()

	writer := &raceBuffer{}
	srv := NewServer(bytes.NewReader([]byte{}), writer, log.New(io.Discard, "", 0), ServerOptions{})

	const goroutines = 16
	// Every goroutine blocks on start; closing it releases them together to
	// maximise contention on the server's output.
	start := make(chan struct{})
	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			<-start
			id := json.RawMessage(strconv.Itoa(i + 1))
			srv.reply(id, map[string]int{"index": i}, nil)
		}(i)
	}
	close(start)
	wg.Wait()

	if writer.Raced() {
		t.Fatal("detected overlapping writes to server output")
	}

	payloads := splitFramedOutput(t, writer.Bytes())
	// Without a count check the test would pass vacuously if reply wrote
	// nothing (or dropped some messages).
	if len(payloads) != goroutines {
		t.Fatalf("got %d framed messages, want %d", len(payloads), goroutines)
	}
	for i, payload := range payloads {
		if !json.Valid(payload) {
			t.Fatalf("message %d is not valid JSON: %q", i, payload)
		}
	}
}

// splitFramedOutput splits data into the payloads of back-to-back
// "Content-Length: N\r\n\r\n<payload>" frames. It fails t if a header delimiter
// is missing, a header is malformed or has a negative length, or a payload is
// truncated.
func splitFramedOutput(t *testing.T, data []byte) [][]byte {
	t.Helper()

	const (
		prefix    = "Content-Length: "
		delimiter = "\r\n\r\n"
	)
	var payloads [][]byte
	for len(data) > 0 {
		headerEnd := bytes.Index(data, []byte(delimiter))
		if headerEnd < 0 {
			t.Fatalf("missing header delimiter in %q", string(data))
		}
		header := string(data[:headerEnd])
		if !strings.HasPrefix(header, prefix) {
			t.Fatalf("unexpected header %q", header)
		}
		lengthStr := strings.TrimSpace(strings.TrimPrefix(header, prefix))
		length, err := strconv.Atoi(lengthStr)
		if err != nil {
			t.Fatalf("invalid content length %q: %v", lengthStr, err)
		}
		// Atoi accepts "-N". A negative length would move the cursor backwards
		// and either panic on the slice or loop forever instead of failing.
		if length < 0 {
			t.Fatalf("negative content length %d", length)
		}
		payloadStart := headerEnd + len(delimiter)
		payloadEnd := payloadStart + length
		if payloadEnd > len(data) {
			t.Fatalf("payload truncated: need %d bytes, have %d", length, len(data)-payloadStart)
		}
		payloads = append(payloads, data[payloadStart:payloadEnd])
		data = data[payloadEnd:]
	}
	return payloads
}