summaryrefslogtreecommitdiff
path: root/internal/hexaicli/simulation.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/simulation.go')
-rw-r--r--internal/hexaicli/simulation.go200
1 files changed, 200 insertions, 0 deletions
diff --git a/internal/hexaicli/simulation.go b/internal/hexaicli/simulation.go
new file mode 100644
index 0000000..fbd7da9
--- /dev/null
+++ b/internal/hexaicli/simulation.go
@@ -0,0 +1,200 @@
+package hexaicli
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "math"
+ "math/rand"
+ "strconv"
+ "strings"
+ "time"
+ "unicode"
+ "unicode/utf8"
+)
+
+type simulationContextKey struct{}
+
+type tpsSimulationSpec struct {
+ min float64
+ max float64
+}
+
+const defaultSimulationText = "Hexai TPS simulation mode is emitting placeholder output so you can gauge how responsive a future model might feel on your hardware. Pipe a file into stdin to preview that exact text at the configured output speed instead."
+
+// WithCLITPSSimulation returns a context that carries the CLI TPS simulation range.
+func WithCLITPSSimulation(ctx context.Context, value string) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return context.WithValue(ctx, simulationContextKey{}, strings.TrimSpace(value))
+}
+
+func tpsSimulationFromContext(ctx context.Context) (tpsSimulationSpec, bool, error) {
+ if ctx == nil {
+ return tpsSimulationSpec{}, false, nil
+ }
+ value, ok := ctx.Value(simulationContextKey{}).(string)
+ if !ok || strings.TrimSpace(value) == "" {
+ return tpsSimulationSpec{}, false, nil
+ }
+ spec, err := parseTPSSimulation(value)
+ if err != nil {
+ return tpsSimulationSpec{}, true, err
+ }
+ return spec, true, nil
+}
+
+func parseTPSSimulation(raw string) (tpsSimulationSpec, error) {
+ value := strings.TrimSpace(raw)
+ if value == "" {
+ return tpsSimulationSpec{}, fmt.Errorf("hexai: --tps-simulation expects <tps> or <min>-<max>")
+ }
+ if strings.Count(value, "-") == 1 && !strings.HasPrefix(value, "-") {
+ return parseTPSSimulationRange(value, "-")
+ }
+ if strings.Count(value, ":") == 1 {
+ return parseTPSSimulationRange(value, ":")
+ }
+ tps, err := parsePositiveTPS(value)
+ if err != nil {
+ return tpsSimulationSpec{}, err
+ }
+ return tpsSimulationSpec{min: tps, max: tps}, nil
+}
+
+func parseTPSSimulationRange(value string, sep string) (tpsSimulationSpec, error) {
+ left, right, ok := strings.Cut(value, sep)
+ if !ok {
+ return tpsSimulationSpec{}, fmt.Errorf("hexai: invalid --tps-simulation value %q", value)
+ }
+ minTPS, err := parsePositiveTPS(left)
+ if err != nil {
+ return tpsSimulationSpec{}, err
+ }
+ maxTPS, err := parsePositiveTPS(right)
+ if err != nil {
+ return tpsSimulationSpec{}, err
+ }
+ if minTPS > maxTPS {
+ return tpsSimulationSpec{}, fmt.Errorf("hexai: --tps-simulation minimum %.2f exceeds maximum %.2f", minTPS, maxTPS)
+ }
+ return tpsSimulationSpec{min: minTPS, max: maxTPS}, nil
+}
+
+func parsePositiveTPS(raw string) (float64, error) {
+ value := strings.TrimSpace(raw)
+ tps, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return 0, fmt.Errorf("hexai: invalid --tps-simulation value %q", value)
+ }
+ if tps <= 0 {
+ return 0, fmt.Errorf("hexai: --tps-simulation requires a positive value, got %q", value)
+ }
+ return tps, nil
+}
+
+func readSimulationInput(stdin io.Reader, args []string) (string, error) {
+ input, err := readInput(stdin, args)
+ if err == nil {
+ return input, nil
+ }
+ if strings.Contains(err.Error(), "no input provided") {
+ return defaultSimulationText, nil
+ }
+ return "", err
+}
+
+func runTPSSimulation(ctx context.Context, spec tpsSimulationSpec, input string, out io.Writer) error {
+ chunks := splitSimulationChunks(input)
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ for i, chunk := range chunks {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ if _, err := io.WriteString(out, chunk); err != nil {
+ return err
+ }
+ if i == len(chunks)-1 {
+ continue
+ }
+ if err := sleepWithContext(ctx, simulationDelay(spec, chunk, rng)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func simulationDelay(spec tpsSimulationSpec, chunk string, rng *rand.Rand) time.Duration {
+ tokens := estimateSimulationTokens(chunk)
+ if tokens == 0 {
+ return 0
+ }
+ tps := spec.min
+ if spec.max > spec.min {
+ tps += rng.Float64() * (spec.max - spec.min)
+ }
+ seconds := float64(tokens) / tps
+ return time.Duration(seconds * float64(time.Second))
+}
+
+func splitSimulationChunks(input string) []string {
+ if input == "" {
+ return nil
+ }
+ chunks := make([]string, 0, strings.Count(input, " ")+1)
+ start := 0
+ sawWord := false
+ for i, r := range input {
+ if unicode.IsSpace(r) {
+ if !sawWord {
+ continue
+ }
+ end := advanceWhitespace(input, i)
+ chunks = append(chunks, input[start:end])
+ start = end
+ sawWord = false
+ continue
+ }
+ sawWord = true
+ }
+ if start < len(input) {
+ chunks = append(chunks, input[start:])
+ }
+ return chunks
+}
+
+func advanceWhitespace(input string, start int) int {
+ end := start
+ for end < len(input) {
+ r, size := utf8.DecodeRuneInString(input[end:])
+ if !unicode.IsSpace(r) {
+ break
+ }
+ end += size
+ }
+ return end
+}
+
+func estimateSimulationTokens(chunk string) int {
+ trimmed := strings.TrimSpace(chunk)
+ if trimmed == "" {
+ return 0
+ }
+ runes := utf8.RuneCountInString(trimmed)
+ return max(1, int(math.Ceil(float64(runes)/4.0)))
+}
+
+func sleepWithContext(ctx context.Context, delay time.Duration) error {
+ if delay <= 0 {
+ return ctx.Err()
+ }
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}