mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
4aa94fcd4c
Add Unwrap() to StatusWriter so http.ResponseController.SetWriteDeadline can reach the underlying net.Conn through the middleware wrapper. Without this, the agent's 20s WriteTimeout killed blocking process output connections. Also add 30s headroom to the write deadline in handleProcessOutput so the response can be written after a full-duration blocking wait. On the tool layer, waitForProcess and the process_output tool now try a non-blocking snapshot on any error, not just context timeout. Transport errors (like the WriteTimeout EOF) previously returned with no process ID and no recovery path. Now if the process finished, the result is returned transparently. If still running, the error includes the process ID and tells the agent to use process_output.
137 lines
3.2 KiB
Go
137 lines
3.2 KiB
Go
package tracing
|
|
|
|
import (
|
|
"bufio"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"runtime"
|
|
"strings"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/buildinfo"
|
|
)
|
|
|
|
var (
|
|
_ http.ResponseWriter = (*StatusWriter)(nil)
|
|
_ http.Hijacker = (*StatusWriter)(nil)
|
|
)
|
|
|
|
// StatusWriter intercepts the status of the response and, when Status >= 400,
// up to maxBodySize (4096) bytes of the response body written via Write. It
// is guaranteed to be the ResponseWriter directly downstream from Middleware.
type StatusWriter struct {
	http.ResponseWriter
	// Status is the status code recorded by WriteHeader, or http.StatusOK if
	// Write was called before any status had been recorded.
	Status int
	// Hijacked is set to true by Hijack when the wrapped ResponseWriter
	// implements http.Hijacker. It is set before delegating, so it stays true
	// even if the underlying Hijack call fails.
	Hijacked bool
	// responseBody holds body bytes captured by Write while Status >= 400,
	// capped at 4096 bytes. It is exposed via ResponseBody.
	responseBody []byte

	// wroteHeader reports whether Status has been recorded yet.
	wroteHeader bool
	// wroteHeaderStack is the caller stack of the WriteHeader call that
	// recorded Status. It is only captured in dev builds and under `go test`,
	// and stays empty if Write recorded the status first.
	wroteHeaderStack string
}
|
|
|
|
func StatusWriterMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
sw := &StatusWriter{ResponseWriter: rw}
|
|
next.ServeHTTP(sw, r)
|
|
})
|
|
}
|
|
|
|
func (w *StatusWriter) WriteHeader(status int) {
|
|
if buildinfo.IsDev() || flag.Lookup("test.v") != nil {
|
|
if w.wroteHeader {
|
|
stack := getStackString(2)
|
|
wroteHeaderStack := w.wroteHeaderStack
|
|
if wroteHeaderStack == "" {
|
|
wroteHeaderStack = "unknown"
|
|
}
|
|
// It's fine that this logs to stdlib logger since it only happens
|
|
// in dev builds and tests.
|
|
log.Printf("duplicate call to (*StatusWriter.).WriteHeader(%d):\n\nstack: %s\n\nheader written at: %s", status, stack, wroteHeaderStack)
|
|
} else {
|
|
w.wroteHeaderStack = getStackString(2)
|
|
}
|
|
}
|
|
if !w.wroteHeader {
|
|
w.Status = status
|
|
w.wroteHeader = true
|
|
}
|
|
w.ResponseWriter.WriteHeader(status)
|
|
}
|
|
|
|
func (w *StatusWriter) Write(b []byte) (int, error) {
|
|
const maxBodySize = 4096
|
|
|
|
if !w.wroteHeader {
|
|
w.Status = http.StatusOK
|
|
w.wroteHeader = true
|
|
}
|
|
|
|
if w.Status >= http.StatusBadRequest {
|
|
// This is technically wrong as multiple calls to write
|
|
// will simply overwrite w.ResponseBody but given that
|
|
// we typically only write to the response body once
|
|
// and this field is only used for logging I'm leaving
|
|
// this as-is.
|
|
w.responseBody = make([]byte, minInt(len(b), maxBodySize))
|
|
copy(w.responseBody, b)
|
|
}
|
|
|
|
return w.ResponseWriter.Write(b)
|
|
}
|
|
|
|
// minInt returns the smaller of a and b (b when they are equal).
func minInt(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
// Unwrap returns the underlying ResponseWriter, allowing
|
|
// http.ResponseController to reach it for SetWriteDeadline, etc.
|
|
func (w *StatusWriter) Unwrap() http.ResponseWriter {
|
|
return w.ResponseWriter
|
|
}
|
|
|
|
func (w *StatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
|
if !ok {
|
|
return nil, nil, xerrors.Errorf("%T is not a http.Hijacker", w.ResponseWriter)
|
|
}
|
|
w.Hijacked = true
|
|
|
|
return hijacker.Hijack()
|
|
}
|
|
|
|
func (w *StatusWriter) ResponseBody() []byte {
|
|
return w.responseBody
|
|
}
|
|
|
|
func (w *StatusWriter) Flush() {
|
|
f, ok := w.ResponseWriter.(http.Flusher)
|
|
if !ok {
|
|
panic("http.ResponseWriter is not http.Flusher")
|
|
}
|
|
f.Flush()
|
|
}
|
|
|
|
// getStackString renders a short call stack as "file:line" entries joined by
// " -> ", innermost frame first. At most five program counters are captured.
//
// skip selects the first frame: 0 is getStackString itself, 1 is its caller,
// 2 is the caller's caller, and so on. If skip lies beyond the top of the
// stack, it returns "" rather than a meaningless ":0" entry.
func getStackString(skip int) string {
	const maxPCs = 5

	pcs := make([]uintptr, maxPCs)
	// runtime.Callers counts itself as frame 0, hence the +1.
	n := runtime.Callers(skip+1, pcs)
	if n == 0 {
		return ""
	}

	frames := runtime.CallersFrames(pcs[:n])
	callers := make([]string, 0, n)
	for {
		frame, more := frames.Next()
		callers = append(callers, fmt.Sprintf("%s:%v", frame.File, frame.Line))
		if !more {
			break
		}
	}
	return strings.Join(callers, " -> ")
}
|