mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
8dc4d76890
<!-- If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting. --> Relates to GRU-18. Adds a basic implementation of the Workspace Agent Connection Watch, with tests. Log handling is not yet implemented.
102 lines
2.9 KiB
Go
102 lines
2.9 KiB
Go
package testutil
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
|
|
"cdr.dev/slog/v3"
|
|
)
|
|
|
|
// InMemWebsocketRoundTripper allows you to "dial" an HTTP handler that sets up a websocket using only in-memory
// primitives. No TCP or OS networking needed. CtxMutator gives you explicit control over the context the handler sees.
//
// Each RoundTrip connects the client and Handler with a net.Pipe and runs Handler in its own goroutine. The
// returned response's Body is the client end of that pipe. RoundTrip returns once the handler has written its
// response status code, or returns the request context's error if that context is done first.
//
// Example:
//
//	rt := testutil.InMemWebsocketRoundTripper{
//		Handler: MyHandler,
//		CtxMutator: func(ctx context.Context) context.Context {
//			ctx = httpmw.WithWorkspaceParam(ctx, ws)
//			ctx = dbauthz.As(ctx, coderdtest.MemberSubject(userID, orgID))
//			return ctx
//		},
//		Logger: logger.Named("roundtripper"),
//	}
//	clientSock, _, err := websocket.Dial(ctx, "wss://local.test/", &websocket.DialOptions{
//		HTTPClient: &http.Client{Transport: rt},
//	})
type InMemWebsocketRoundTripper struct {
	// Logger receives the debug-level messages RoundTrip emits (round trip start/end, and closing a
	// connection the handler did not hijack).
	Logger slog.Logger
	// Handler serves every request sent through this round tripper. RoundTrip only returns a response
	// after the handler writes a status code; until then it waits on the request context.
	Handler http.Handler
	// CtxMutator is applied to the request's context before Handler is invoked. Use it to inject values
	// that middleware would normally add (see the example above: route params, authorization subject).
	CtxMutator func(ctx context.Context) context.Context
}
|
|
|
|
func (i InMemWebsocketRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
|
i.Logger.Debug(context.Background(), "round trip start")
|
|
defer i.Logger.Debug(context.Background(), "round trip end")
|
|
newCtx := i.CtxMutator(request.Context())
|
|
request = request.WithContext(newCtx)
|
|
serverP, clientP := net.Pipe()
|
|
var _ io.ReadWriteCloser = clientP // compile time check that response body is OK for websocket
|
|
response := &http.Response{
|
|
Header: make(http.Header),
|
|
Body: clientP,
|
|
}
|
|
rw := newInMemWebsocketResponseWriter(response, serverP)
|
|
go func() {
|
|
i.Handler.ServeHTTP(rw, request)
|
|
if !rw.hijacked {
|
|
i.Logger.Debug(context.Background(), "closing connection after handler did not hijack")
|
|
// If the handler didn't hijack the connection, we should close it when the handler finishes.
|
|
// This prevents a 3s delay in websocket.Dial() reading the non-upgraded response.
|
|
_ = serverP.Close()
|
|
}
|
|
}()
|
|
select {
|
|
case <-newCtx.Done():
|
|
return nil, newCtx.Err()
|
|
case <-rw.gotHeaders:
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// newInMemWebsocketResponseWriter returns a response writer that records the status code and headers
// directly into resp. Body bytes, and any traffic after a hijack, go over conn through a single shared
// bufio.ReadWriter.
func newInMemWebsocketResponseWriter(resp *http.Response, conn net.Conn) *inMemWebsocketResponseWriter {
	return &inMemWebsocketResponseWriter{
		r:          resp,
		b:          bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
		gotHeaders: make(chan struct{}),
		conn:       conn,
	}
}

// inMemWebsocketResponseWriter is the server side of InMemWebsocketRoundTripper. It implements
// http.ResponseWriter and http.Hijacker on top of one end of an in-memory connection. As with net/http's
// ResponseWriter, it must only be used from the handler's goroutine.
type inMemWebsocketResponseWriter struct {
	// r is the response handed back to the client; Header and WriteHeader populate it directly.
	r *http.Response
	// b buffers reads and writes on conn; it is shared by Write and Hijack.
	b *bufio.ReadWriter
	// gotHeaders is closed exactly once, when the status code has been recorded in r.
	gotHeaders chan struct{}
	// hijacked records whether the handler has taken over conn via Hijack.
	hijacked bool
	// conn is the server end of the in-memory connection.
	conn net.Conn
}

// Hijack implements http.Hijacker. It marks the writer as hijacked and hands conn and the shared
// buffered reader/writer to the caller. It never returns an error.
func (rw *inMemWebsocketResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw.hijacked = true
	return rw.conn, rw.b, nil
}

// Header returns the header map of the response that is handed back to the client.
func (rw *inMemWebsocketResponseWriter) Header() http.Header {
	return rw.r.Header
}

// Write sends p to the client as response body, following net/http semantics:
//   - if WriteHeader has not been called, a 200 OK status is written first;
//   - after Hijack, Write fails with http.ErrHijacked.
//
// Every call is flushed because nothing else flushes the buffered writer before a non-hijacked connection is
// closed. Because conn is an unbuffered in-memory pipe, Write blocks until the client reads the data or
// closes its end.
func (rw *inMemWebsocketResponseWriter) Write(p []byte) (int, error) {
	if rw.hijacked {
		return 0, http.ErrHijacked
	}
	rw.WriteHeader(http.StatusOK) // no-op if the status was already written
	n, err := rw.b.Write(p)
	if err != nil {
		return n, err
	}
	return n, rw.b.Flush()
}

// WriteHeader records statusCode on the response and closes gotHeaders so the waiting round trip can return
// it. As in net/http, only the first call has an effect; later calls are ignored instead of panicking on a
// double channel close. The check-then-close relies on all calls coming from the handler's goroutine.
func (rw *inMemWebsocketResponseWriter) WriteHeader(statusCode int) {
	select {
	case <-rw.gotHeaders:
		return // status already written
	default:
	}
	rw.r.StatusCode = statusCode
	close(rw.gotHeaders)
}
|