From 3e44adc94f32bfe500730fcbf1c02cedf65b0a30 Mon Sep 17 00:00:00 2001 From: radhitya Date: Sat, 13 Jun 2026 16:09:53 +0700 Subject: root hints, glue record, delegation loop, iterative, ns fallback, timeout, glue record --- internal/server/doh.go | 2 +- internal/server/handler.go | 37 ++++++++++--------- internal/server/server.go | 17 +++++---- internal/server/server_test.go | 83 ++++++++++++++++++------------------------ 4 files changed, 65 insertions(+), 74 deletions(-) (limited to 'internal/server') diff --git a/internal/server/doh.go b/internal/server/doh.go index e9cf466..3f5a538 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } - resp := buildResponse(msg) + resp := s.buildResponse(msg) packed, err := resp.Pack() if err != nil { http.Error(w, "pack response", http.StatusInternalServerError) diff --git a/internal/server/handler.go b/internal/server/handler.go index 2e2f08b..bac1c81 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -1,11 +1,13 @@ package server import ( + "context" "github.com/miekg/dns" "log/slog" + "time" ) -func handleQuery(w dns.ResponseWriter, req *dns.Msg) { +func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { m := new(dns.Msg) m.SetRcode(req, dns.RcodeFormatError) @@ -13,7 +15,7 @@ func handleQuery(w dns.ResponseWriter, req *dns.Msg) { return } - resp := buildResponse(req) + resp := s.buildResponse(req) if err := w.WriteMsg(resp); err != nil { slog.Error("write response failed", @@ -31,7 +33,7 @@ func handleQuery(w dns.ResponseWriter, req *dns.Msg) { ) } -func buildResponse(req *dns.Msg) *dns.Msg { +func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { if len(req.Question) == 0 { m := new(dns.Msg) m.SetRcode(req, dns.RcodeFormatError) @@ -46,20 +48,19 @@ func buildResponse(req *dns.Msg) *dns.Msg { resp.SetEdns0(4096, false) } - if q.Name == "example.com." && q.Qtype == dns.TypeA { - resp.Answer = []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 60, - }, - A: []byte{127, 0, 0, 1}, - }, - } - } else { - resp.Rcode = dns.RcodeNameError + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) + if err != nil { + slog.Error("resolution failed", + "err", err, + "qname", q.Name, + "qtype", dns.TypeToString[q.Qtype], + ) + resp.Rcode = dns.RcodeServerFailure + return resp } - return resp + + return reply } diff --git a/internal/server/server.go b/internal/server/server.go index f40648e..3114073 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,20 +7,21 @@ import ( "time" "github.com/miekg/dns" + "sdns/internal/resolver" ) type Server struct { - logger *slog.Logger - udp *dns.Server - tcp *dns.Server - doh *http.Server + logger *slog.Logger + resolver *resolver.Resolver + udp *dns.Server + tcp *dns.Server + doh *http.Server } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger) (*Server, error) { +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) { + s := &Server{logger: logger, resolver: r} mux := dns.NewServeMux() - mux.HandleFunc(".", handleQuery) - - s := &Server{logger: logger} + mux.HandleFunc(".", s.handleQuery) if udpAddr != "" { s.udp = &dns.Server{ diff --git a/internal/server/server_test.go b/internal/server/server_test.go index eaf0190..6fc2092 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,12 +1,25 @@ package server import ( + "log/slog" "testing" + "time" "github.com/miekg/dns" + "sdns/internal/resolver" ) +func testServer(t *testing.T) *Server { + t.Helper() + r := resolver.New( + resolver.WithRootAddresses([]string{"127.0.0.1:1"}), + resolver.WithTimeout(50*time.Millisecond), + ) + return &Server{logger: slog.Default(), resolver: r} +} + func TestBuildResponse(t *testing.T) { + s := testServer(t) tests := []struct { name string req *dns.Msg @@ -14,39 +27,6 @@ func TestBuildResponse(t *testing.T) { wantAnswers int wantEdns0 bool }{ - { - name: "example.com A returns 127.0.0.1", - req: func() *dns.Msg { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - return m - }(), - wantRcode: dns.RcodeSuccess, - wantAnswers: 1, - wantEdns0: false, - }, - { - name: "google.com A returns NXDOMAIN", - req: func() *dns.Msg { - m := new(dns.Msg) - m.SetQuestion("google.com.", dns.TypeA) - return m - }(), - wantRcode: dns.RcodeNameError, - wantAnswers: 0, - wantEdns0: false, - }, - { - name: "other.com A returns NXDOMAIN", - req: func() *dns.Msg { - m := new(dns.Msg) - m.SetQuestion("other.com.", dns.TypeA) - return m - }(), - wantRcode: dns.RcodeNameError, - wantAnswers: 0, - wantEdns0: false, - }, { name: "no questions returns FORMERR", req: func() *dns.Msg { @@ -56,23 +36,11 @@ func TestBuildResponse(t *testing.T) { wantAnswers: 0, wantEdns0: false, }, - { - name: "EDNS0 query preserved with 4096 buffer", - req: func() *dns.Msg { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetEdns0(1232, true) - return m - }(), - wantRcode: dns.RcodeSuccess, - wantAnswers: 1, - wantEdns0: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := buildResponse(tt.req) + resp := s.buildResponse(tt.req) if resp.Rcode != tt.wantRcode { t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode) } @@ -90,7 +58,28 @@ func TestBuildResponse(t *testing.T) { } } +func TestBuildResponseWithQuery(t *testing.T) { + s := testServer(t) + // Valid query → must not panic, rcode must be valid + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + resp := s.buildResponse(m) + if resp == nil { + t.Fatal("buildResponse returned nil") + } + if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeServerFailure { + t.Errorf("expected success or server failure, got %d", resp.Rcode) + } +} + func FuzzBuildResponse(f *testing.F) { + s := &Server{logger: slog.Default()} + // For fuzz, use a resolver that won't make real network calls + s.resolver = resolver.New( + resolver.WithRootAddresses([]string{"127.0.0.1:1"}), + resolver.WithTimeout(10*time.Millisecond), + ) + seed := []byte{ 0x00, 0x00, // ID 0x01, 0x00, // flags: RD @@ -111,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) { if err := msg.Unpack(data); err != nil { return } - resp := buildResponse(msg) + resp := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } -- cgit v1.2.3