Added support for io.ReadAll on requests
All checks were successful
/ test (push) Successful in 18s

This commit is contained in:
Eduard Urbach 2025-06-21 11:37:39 +02:00
parent 9a781d2e64
commit 4c19cd0e99
Signed by: eduard
GPG key ID: 49226B848C78F6C8
4 changed files with 65 additions and 18 deletions

View file

@ -3,6 +3,8 @@ package web
import (
"bufio"
"io"
"strconv"
"strings"
"git.urbach.dev/go/router"
)
@ -20,7 +22,7 @@ type Request interface {
// request represents the HTTP request used in the given context.
type request struct {
bufio.Reader
reader bufio.Reader
scheme string
host string
method string
@ -28,12 +30,14 @@ type request struct {
query string
headers []Header
params []router.Parameter
length int
consumed int
}
// Header returns the header value for the given key.
func (req *request) Header(key string) string {
for _, header := range req.headers {
if header.Key == key {
if strings.EqualFold(header.Key, key) {
return header.Value
}
}
@ -69,6 +73,26 @@ func (req *request) Path() string {
return req.path
}
// Read implements the io.Reader interface.
func (req *request) Read(p []byte) (n int, err error) {
if req.length == 0 {
req.length, _ = strconv.Atoi(req.Header("Content-Length"))
if req.length == 0 {
return 0, io.EOF
}
}
n, err = req.reader.Read(p)
req.consumed += n
if req.consumed < req.length {
return n, err
}
return n - (req.consumed - req.length), io.EOF
}
// Scheme returns either `http`, `https` or an empty string.
func (req request) Scheme() string {
return req.scheme

View file

@ -2,6 +2,8 @@ package web_test
import (
"fmt"
"io"
"strconv"
"strings"
"testing"
@ -30,19 +32,39 @@ func TestRequestBody(t *testing.T) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
body := make([]byte, 4096)
n, err := ctx.Request().Read(body)
body, err := io.ReadAll(ctx.Request())
if err != nil {
return err
}
return ctx.Bytes(body[:n])
return ctx.Bytes(body)
})
response := s.Request("GET", "/", nil, strings.NewReader("Hello"))
body := strings.Repeat("Hello", 1000)
headers := []web.Header{{Key: "Content-Length", Value: strconv.Itoa(len(body))}}
response := s.Request("GET", "/", headers, strings.NewReader(body))
assert.Equal(t, response.Status(), 200)
assert.Equal(t, string(response.Body()), "Hello")
assert.Equal(t, string(response.Body()), body)
}
func TestRequestBodyMissingLength(t *testing.T) {
s := web.NewServer()
s.Get("/", func(ctx web.Context) error {
body, err := io.ReadAll(ctx.Request())
if err != nil {
return err
}
return ctx.Bytes(body)
})
body := strings.Repeat("Hello", 1000)
response := s.Request("GET", "/", nil, strings.NewReader(body))
assert.Equal(t, response.Status(), 200)
assert.Equal(t, string(response.Body()), "")
}
func TestRequestHeader(t *testing.T) {

View file

@ -78,7 +78,7 @@ func (s *server) Ready() chan struct{} {
func (s *server) Request(method string, url string, headers []Header, body io.Reader) Response {
ctx := s.newContext()
ctx.request.headers = headers
ctx.request.Reader.Reset(body)
ctx.request.reader.Reset(body)
s.handleRequest(ctx, method, url, io.Discard)
return ctx.Response()
}
@ -133,14 +133,14 @@ func (s *server) handleConnection(conn net.Conn) {
close bool
)
ctx.Reader.Reset(conn)
ctx.reader.Reset(conn)
defer conn.Close()
defer s.contextPool.Put(ctx)
for !close {
// Read the HTTP request line
message, err := ctx.Reader.ReadString('\n')
message, err := ctx.reader.ReadString('\n')
if err != nil {
return
@ -177,7 +177,7 @@ func (s *server) handleConnection(conn net.Conn) {
// Add headers until we meet an empty line
for {
message, err = ctx.Reader.ReadString('\n')
message, err = ctx.reader.ReadString('\n')
if err != nil {
return

View file

@ -58,6 +58,7 @@ PASS: TestErrorMultiple
PASS: TestRedirect
PASS: TestRequest
PASS: TestRequestBody
PASS: TestRequestBodyMissingLength
PASS: TestRequestHeader
PASS: TestRequestParam
PASS: TestWrite