summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handler.go2
-rw-r--r--internal/server/server.go8
-rw-r--r--internal/server/server_test.go10
3 files changed, 16 insertions, 4 deletions
diff --git a/internal/server/handler.go b/internal/server/handler.go
index 406b7ed..e996270 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -55,7 +55,7 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
return s.blockedResponse(req), true
}
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
diff --git a/internal/server/server.go b/internal/server/server.go
index ec0dec9..8f991eb 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -20,11 +20,17 @@ type Server struct {
udp *dns.Server
tcp *dns.Server
doh *http.Server
+ baseCtx context.Context
+ cancel context.CancelFunc
}
func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
- s := &Server{logger: logger, resolver: r, cache: c, blocklist: b}
+ baseCtx, cancel := context.WithCancel(context.Background())
+
+ s := &Server{logger: logger, resolver: r, cache: c, blocklist: b,
+ baseCtx: baseCtx, cancel: cancel}
+
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index e59a131..42938d1 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -2,6 +2,7 @@ package server
import (
"log/slog"
+ "context"
"testing"
"time"
@@ -15,7 +16,10 @@ func testServer(t *testing.T) *Server {
resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
resolver.WithTimeout(50*time.Millisecond),
)
- return &Server{logger: slog.Default(), resolver: r}
+ baseCtx, cancel := context.WithCancel(context.Background())
+
+ t.Cleanup(cancel)
+ return &Server{logger: slog.Default(), resolver: r, baseCtx: baseCtx}
}
func TestBuildResponse(t *testing.T) {
@@ -73,7 +77,9 @@ func TestBuildResponseWithQuery(t *testing.T) {
}
func FuzzBuildResponse(f *testing.F) {
- s := &Server{logger: slog.Default()}
+ baseCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ s := &Server{logger: slog.Default(), baseCtx: baseCtx}
// For fuzz, use a resolver that won't make real network calls
s.resolver = resolver.New(
resolver.WithRootAddresses([]string{"127.0.0.1:1"}),