Added support for io.ReadAll on requests
All checks were successful
/ test (push) Successful in 18s

This commit is contained in:
Eduard Urbach 2025-06-21 11:37:39 +02:00
parent 9a781d2e64
commit 4c19cd0e99
Signed by: eduard
GPG key ID: 49226B848C78F6C8
4 changed files with 65 additions and 18 deletions

View file

@ -3,6 +3,8 @@ package web
import ( import (
"bufio" "bufio"
"io" "io"
"strconv"
"strings"
"git.urbach.dev/go/router" "git.urbach.dev/go/router"
) )
@ -20,20 +22,22 @@ type Request interface {
// request represents the HTTP request used in the given context. // request represents the HTTP request used in the given context.
type request struct { type request struct {
bufio.Reader reader bufio.Reader
scheme string scheme string
host string host string
method string method string
path string path string
query string query string
headers []Header headers []Header
params []router.Parameter params []router.Parameter
length int
consumed int
} }
// Header returns the header value for the given key. // Header returns the header value for the given key.
func (req *request) Header(key string) string { func (req *request) Header(key string) string {
for _, header := range req.headers { for _, header := range req.headers {
if header.Key == key { if strings.EqualFold(header.Key, key) {
return header.Value return header.Value
} }
} }
@ -69,6 +73,26 @@ func (req *request) Path() string {
return req.path return req.path
} }
// Read implements the io.Reader interface.
func (req *request) Read(p []byte) (n int, err error) {
if req.length == 0 {
req.length, _ = strconv.Atoi(req.Header("Content-Length"))
if req.length == 0 {
return 0, io.EOF
}
}
n, err = req.reader.Read(p)
req.consumed += n
if req.consumed < req.length {
return n, err
}
return n - (req.consumed - req.length), io.EOF
}
// Scheme returns either `http`, `https` or an empty string. // Scheme returns either `http`, `https` or an empty string.
func (req request) Scheme() string { func (req request) Scheme() string {
return req.scheme return req.scheme

View file

@ -2,6 +2,8 @@ package web_test
import ( import (
"fmt" "fmt"
"io"
"strconv"
"strings" "strings"
"testing" "testing"
@ -30,19 +32,39 @@ func TestRequestBody(t *testing.T) {
s := web.NewServer() s := web.NewServer()
s.Get("/", func(ctx web.Context) error { s.Get("/", func(ctx web.Context) error {
body := make([]byte, 4096) body, err := io.ReadAll(ctx.Request())
n, err := ctx.Request().Read(body)
if err != nil { if err != nil {
return err return err
} }
return ctx.Bytes(body[:n]) return ctx.Bytes(body)
}) })
response := s.Request("GET", "/", nil, strings.NewReader("Hello")) body := strings.Repeat("Hello", 1000)
headers := []web.Header{{Key: "Content-Length", Value: strconv.Itoa(len(body))}}
response := s.Request("GET", "/", headers, strings.NewReader(body))
assert.Equal(t, response.Status(), 200) assert.Equal(t, response.Status(), 200)
assert.Equal(t, string(response.Body()), "Hello") assert.Equal(t, string(response.Body()), body)
}
func TestRequestBodyMissingLength(t *testing.T) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
body, err := io.ReadAll(ctx.Request())
if err != nil {
return err
}
return ctx.Bytes(body)
})
body := strings.Repeat("Hello", 1000)
response := s.Request("GET", "/", nil, strings.NewReader(body))
assert.Equal(t, response.Status(), 200)
assert.Equal(t, string(response.Body()), "")
} }
func TestRequestHeader(t *testing.T) { func TestRequestHeader(t *testing.T) {

View file

@ -78,7 +78,7 @@ func (s *server) Ready() chan struct{} {
func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response { func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response {
ctx := s.newContext() ctx := s.newContext()
ctx.request.headers = headers ctx.request.headers = headers
ctx.request.Reader.Reset(body) ctx.request.reader.Reset(body)
s.handleRequest(ctx, method, url, io.Discard) s.handleRequest(ctx, method, url, io.Discard)
return ctx.Response() return ctx.Response()
} }
@ -133,14 +133,14 @@ func (s *server) handleConnection(conn net.Conn) {
close bool close bool
) )
ctx.Reader.Reset(conn) ctx.reader.Reset(conn)
defer conn.Close() defer conn.Close()
defer s.contextPool.Put(ctx) defer s.contextPool.Put(ctx)
for !close { for !close {
// Read the HTTP request line // Read the HTTP request line
message, err := ctx.Reader.ReadString('\n') message, err := ctx.reader.ReadString('\n')
if err != nil { if err != nil {
return return
@ -177,7 +177,7 @@ func (s *server) handleConnection(conn net.Conn) {
// Add headers until we meet an empty line // Add headers until we meet an empty line
for { for {
message, err = ctx.Reader.ReadString('\n') message, err = ctx.reader.ReadString('\n')
if err != nil { if err != nil {
return return

View file

@ -58,6 +58,7 @@ PASS: TestErrorMultiple
PASS: TestRedirect PASS: TestRedirect
PASS: TestRequest PASS: TestRequest
PASS: TestRequestBody PASS: TestRequestBody
PASS: TestRequestBodyMissingLength
PASS: TestRequestHeader PASS: TestRequestHeader
PASS: TestRequestParam PASS: TestRequestParam
PASS: TestWrite PASS: TestWrite