From 4c19cd0e99cc0fb034a50787500a17a3cb3ab5b2 Mon Sep 17 00:00:00 2001 From: Eduard Urbach Date: Sat, 21 Jun 2025 11:37:39 +0200 Subject: [PATCH] Added support for io.ReadAll on requests --- Request.go | 42 +++++++++++++++++++++++++++++++++--------- Request_test.go | 32 +++++++++++++++++++++++++++----- Server.go | 8 ++++---- readme.md | 1 + 4 files changed, 65 insertions(+), 18 deletions(-) diff --git a/Request.go b/Request.go index 1b9c3d0..1230590 100644 --- a/Request.go +++ b/Request.go @@ -3,6 +3,8 @@ package web import ( "bufio" "io" + "strconv" + "strings" "git.urbach.dev/go/router" ) @@ -20,20 +22,22 @@ type Request interface { // request represents the HTTP request used in the given context. type request struct { - bufio.Reader - scheme string - host string - method string - path string - query string - headers []Header - params []router.Parameter + reader bufio.Reader + scheme string + host string + method string + path string + query string + headers []Header + params []router.Parameter + length int + consumed int } // Header returns the header value for the given key. func (req *request) Header(key string) string { for _, header := range req.headers { - if header.Key == key { + if strings.EqualFold(header.Key, key) { return header.Value } } @@ -69,6 +73,26 @@ func (req *request) Path() string { return req.path } +// Read implements the io.Reader interface. +func (req *request) Read(p []byte) (n int, err error) { + if req.length == 0 { + req.length, _ = strconv.Atoi(req.Header("Content-Length")) + + if req.length == 0 { + return 0, io.EOF + } + } + + n, err = req.reader.Read(p) + req.consumed += n + + if req.consumed < req.length { + return n, err + } + + return n - (req.consumed - req.length), io.EOF +} + // Scheme returns either `http`, `https` or an empty string. func (req request) Scheme() string { return req.scheme diff --git a/Request_test.go b/Request_test.go index b21a8bd..7774008 100644 --- a/Request_test.go +++ b/Request_test.go @@ -2,6 +2,8 @@ package web_test import ( "fmt" + "io" + "strconv" "strings" "testing" @@ -30,19 +32,39 @@ func TestRequestBody(t *testing.T) { s := web.NewServer() s.Get("/", func(ctx web.Context) error { - body := make([]byte, 4096) - n, err := ctx.Request().Read(body) + body, err := io.ReadAll(ctx.Request()) if err != nil { return err } - return ctx.Bytes(body[:n]) + return ctx.Bytes(body) }) - response := s.Request("GET", "/", nil, strings.NewReader("Hello")) + body := strings.Repeat("Hello", 1000) + headers := []web.Header{{Key: "Content-Length", Value: strconv.Itoa(len(body))}} + response := s.Request("GET", "/", headers, strings.NewReader(body)) assert.Equal(t, response.Status(), 200) - assert.Equal(t, string(response.Body()), "Hello") + assert.Equal(t, string(response.Body()), body) +} + +func TestRequestBodyMissingLength(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + body, err := io.ReadAll(ctx.Request()) + + if err != nil { + return err + } + + return ctx.Bytes(body) + }) + + body := strings.Repeat("Hello", 1000) + response := s.Request("GET", "/", nil, strings.NewReader(body)) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "") } func TestRequestHeader(t *testing.T) { diff --git a/Server.go b/Server.go index a851fc4..79fd92f 100644 --- a/Server.go +++ b/Server.go @@ -78,7 +78,7 @@ func (s *server) Ready() chan struct{} { func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response { ctx := s.newContext() ctx.request.headers = headers - ctx.request.Reader.Reset(body) + ctx.request.reader.Reset(body) s.handleRequest(ctx, method, url, io.Discard) return ctx.Response() } @@ -133,14 +133,14 @@ func (s *server) handleConnection(conn net.Conn) { close bool ) - ctx.Reader.Reset(conn) + ctx.reader.Reset(conn) defer conn.Close() defer s.contextPool.Put(ctx) for !close { // Read the HTTP request line - message, err := ctx.Reader.ReadString('\n') + message, err := ctx.reader.ReadString('\n') if err != nil { return @@ -177,7 +177,7 @@ func (s *server) handleConnection(conn net.Conn) { // Add headers until we meet an empty line for { - message, err = ctx.Reader.ReadString('\n') + message, err = ctx.reader.ReadString('\n') if err != nil { return diff --git a/readme.md b/readme.md index 66bfe4e..120252b 100644 --- a/readme.md +++ b/readme.md @@ -58,6 +58,7 @@ PASS: TestErrorMultiple PASS: TestRedirect PASS: TestRequest PASS: TestRequestBody +PASS: TestRequestBodyMissingLength PASS: TestRequestHeader PASS: TestRequestParam PASS: TestWrite