From 5af7efcdbdcecf1295f18cf214d7847bb047c599 Mon Sep 17 00:00:00 2001 From: Eduard Urbach Date: Thu, 28 Mar 2024 14:27:40 +0100 Subject: [PATCH] Added more tests --- README.md | 7 ++- Server.go | 57 ++++++++++++------------ Server_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 149 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 515dccd..bf2568b 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,13 @@ PASS: TestResponseHeader PASS: TestResponseHeaderOverwrite PASS: TestPanic PASS: TestRun +PASS: TestBadRequest +PASS: TestBadRequestHeader +PASS: TestBadRequestMethod +PASS: TestBadRequestProtocol +PASS: TestEarlyClose PASS: TestUnavailablePort -coverage: 95.5% of statements +coverage: 100.0% of statements ``` ## Benchmarks diff --git a/Server.go b/Server.go index 51e8130..9ac676b 100644 --- a/Server.go +++ b/Server.go @@ -128,39 +128,38 @@ func (s *server) handleConnection(conn net.Conn) { defer s.contextPool.Put(ctx) for { - // Search for a line containing HTTP method and url - for { - message, err := ctx.reader.ReadString('\n') + // Read the HTTP request line + message, err := ctx.reader.ReadString('\n') - if err != nil { - return - } - - space := strings.IndexByte(message, ' ') - - if space <= 0 { - continue - } - - method = message[:space] - - if !isRequestMethod(method) { - continue - } - - lastSpace := strings.LastIndexByte(message, ' ') - - if lastSpace == -1 { - lastSpace = len(message) - } - - url = message[space+1 : lastSpace] - break + if err != nil { + return } + space := strings.IndexByte(message, ' ') + + if space <= 0 { + fmt.Fprint(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + method = message[:space] + + if !isRequestMethod(method) { + fmt.Fprint(conn, "HTTP/1.1 400 Bad Request\r\n\r\n") + return + } + + lastSpace := strings.LastIndexByte(message, ' ') + + if lastSpace == space { + lastSpace = len(message) - len("\r\n") + } + + url = message[space+1 : lastSpace] + // Add headers until we meet an empty line for { - message, err := ctx.reader.ReadString('\n') + message, err = ctx.reader.ReadString('\n') if err != nil { return @@ -210,7 +209,7 @@ func (s *server) handleRequest(ctx *context, method string, url string, writer i s.errorHandler(ctx, err) } - fmt.Fprintf(writer, "HTTP/1.1 %d %s\r\nContent-Length: %d\r\n%s\r\n%s", ctx.status, "OK", len(ctx.response.body), ctx.response.headerText(), ctx.response.body) + fmt.Fprintf(writer, "HTTP/1.1 %d\r\nContent-Length: %d\r\n%s\r\n%s", ctx.status, len(ctx.response.body), ctx.response.headerText(), ctx.response.body) } // newContext allocates a new context with the default state. diff --git a/Server_test.go b/Server_test.go index ff942a8..9793aef 100644 --- a/Server_test.go +++ b/Server_test.go @@ -1,6 +1,7 @@ package web_test import ( + "io" "net" "net/http" "syscall" @@ -32,9 +33,122 @@ func TestRun(t *testing.T) { s := web.NewServer() go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + _, err := http.Get("http://127.0.0.1:8080/") assert.Nil(t, err) - err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + }() + + s.Run(":8080") +} + +func TestBadRequest(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", ":8080") + assert.Nil(t, err) + defer conn.Close() + + _, err = io.WriteString(conn, "BadRequest\r\n\r\n") + assert.Nil(t, err) + + response, err := io.ReadAll(conn) + assert.Nil(t, err) + assert.Equal(t, string(response), "HTTP/1.1 400 Bad Request\r\n\r\n") + }() + + s.Run(":8080") +} + +func TestBadRequestHeader(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", ":8080") + assert.Nil(t, err) + defer conn.Close() + + _, err = io.WriteString(conn, "GET / HTTP/1.1\r\nBadHeader\r\nGood: Header\r\n\r\n") + assert.Nil(t, err) + + buffer := make([]byte, len("HTTP/1.1 200")) + _, err = conn.Read(buffer) + assert.Nil(t, err) + assert.Equal(t, string(buffer), "HTTP/1.1 200") + }() + + s.Run(":8080") +} + +func TestBadRequestMethod(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", ":8080") + assert.Nil(t, err) + defer conn.Close() + + _, err = io.WriteString(conn, "BAD-METHOD / HTTP/1.1\r\n\r\n") + assert.Nil(t, err) + + response, err := io.ReadAll(conn) + assert.Nil(t, err) + assert.Equal(t, string(response), "HTTP/1.1 400 Bad Request\r\n\r\n") + }() + + s.Run(":8080") +} + +func TestBadRequestProtocol(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", ":8080") + assert.Nil(t, err) + defer conn.Close() + + _, err = io.WriteString(conn, "GET /\r\n\r\n") + assert.Nil(t, err) + + buffer := make([]byte, len("HTTP/1.1 200")) + _, err = conn.Read(buffer) + assert.Nil(t, err) + assert.Equal(t, string(buffer), "HTTP/1.1 200") + }() + + s.Run(":8080") +} + +func TestEarlyClose(t *testing.T) { + s := web.NewServer() + + go func() { + defer syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + conn, err := net.Dial("tcp", ":8080") + assert.Nil(t, err) + + _, err = io.WriteString(conn, "GET /\r\n") + assert.Nil(t, err) + + err = conn.Close() assert.Nil(t, err) }()