diff options
Diffstat (limited to 'internal/flamegraph/liveserver.go')
| -rw-r--r-- | internal/flamegraph/liveserver.go | 107 |
1 files changed, 105 insertions, 2 deletions
diff --git a/internal/flamegraph/liveserver.go b/internal/flamegraph/liveserver.go index de65ee3..4cc5629 100644 --- a/internal/flamegraph/liveserver.go +++ b/internal/flamegraph/liveserver.go @@ -6,8 +6,13 @@ import ( "errors" "fmt" "net/http" + "os" "os/exec" + "os/user" + "path/filepath" + "strconv" "strings" + "syscall" "time" ) @@ -19,6 +24,7 @@ var liveServerTimeouts = serverTimeouts{ type LiveServerOptions struct { OpenCommand string + WarningCb func(message string) } var openBrowserURLFn = openBrowserURL @@ -39,7 +45,7 @@ func ServeLiveWithOptions(ctx context.Context, lt *LiveTrie, interval time.Durat url := fmt.Sprintf("http://%s:%d/", hostname, port) fmt.Printf("Live flamegraph available at %s\n", url) if err := maybeOpenLiveBrowser(url, options); err != nil { - fmt.Printf("Live flamegraph browser auto-open failed: %v\n", err) + notifyLiveWarning(options.WarningCb, fmt.Sprintf("Live flamegraph browser auto-open failed: %v", err)) } }) } @@ -57,13 +63,110 @@ func openBrowserURL(url, openCommand string) error { return err } cmd := exec.Command(parts[0], parts[1:]...) + applySudoInvokerContext(cmd) if err := cmd.Start(); err != nil { return err } - go func() { _ = cmd.Wait() }() + + waitCh := make(chan error, 1) + go func() { waitCh <- cmd.Wait() }() + + select { + case waitErr := <-waitCh: + if waitErr != nil { + return fmt.Errorf("browser command exited early: %w", waitErr) + } + case <-time.After(750 * time.Millisecond): + } return nil } +func notifyLiveWarning(warningCb func(string), message string) { + if message == "" { + return + } + if warningCb != nil { + warningCb(message) + return + } + fmt.Println(message) +} + +func applySudoInvokerContext(cmd *exec.Cmd) { + applySudoInvokerContextWithEnv(cmd, os.Geteuid(), os.Environ()) +} + +func applySudoInvokerContextWithEnv(cmd *exec.Cmd, euid int, env []string) { + if cmd == nil || euid != 0 { + return + } + + sudoUIDStr, okUID := lookupEnvValue(env, "SUDO_UID") + sudoGIDStr, okGID := lookupEnvValue(env, "SUDO_GID") + if !okUID || !okGID { + return + } + + uid, errUID := strconv.ParseUint(strings.TrimSpace(sudoUIDStr), 10, 32) + gid, errGID := strconv.ParseUint(strings.TrimSpace(sudoGIDStr), 10, 32) + if errUID != nil || errGID != nil { + return + } + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(gid), + }, + } + + launchEnv := append([]string(nil), env...) + if sudoUser, ok := lookupEnvValue(env, "SUDO_USER"); ok && strings.TrimSpace(sudoUser) != "" { + launchEnv = upsertEnvValue(launchEnv, "USER", sudoUser) + launchEnv = upsertEnvValue(launchEnv, "LOGNAME", sudoUser) + } + + if sudoUser, err := user.LookupId(strconv.FormatUint(uid, 10)); err == nil && strings.TrimSpace(sudoUser.HomeDir) != "" { + launchEnv = upsertEnvValue(launchEnv, "HOME", sudoUser.HomeDir) + if _, ok := lookupEnvValue(launchEnv, "XAUTHORITY"); !ok { + xauth := filepath.Join(sudoUser.HomeDir, ".Xauthority") + if info, statErr := os.Stat(xauth); statErr == nil && !info.IsDir() { + launchEnv = upsertEnvValue(launchEnv, "XAUTHORITY", xauth) + } + } + } + + if _, ok := lookupEnvValue(launchEnv, "XDG_RUNTIME_DIR"); !ok { + runtimeDir := fmt.Sprintf("/run/user/%d", uid) + if info, statErr := os.Stat(runtimeDir); statErr == nil && info.IsDir() { + launchEnv = upsertEnvValue(launchEnv, "XDG_RUNTIME_DIR", runtimeDir) + } + } + + cmd.Env = launchEnv +} + +func lookupEnvValue(env []string, key string) (string, bool) { + prefix := key + "=" + for _, entry := range env { + if strings.HasPrefix(entry, prefix) { + return strings.TrimPrefix(entry, prefix), true + } + } + return "", false +} + +func upsertEnvValue(env []string, key, value string) []string { + prefix := key + "=" + for i := range env { + if strings.HasPrefix(env[i], prefix) { + env[i] = prefix + value + return env + } + } + return append(env, prefix+value) +} + func browserOpenCommandParts(openCommand, url string) ([]string, error) { parts := strings.Fields(strings.TrimSpace(openCommand)) if len(parts) == 0 { |
