From f5753c6a8cac5a57a042b0388f38abeff5d1f37d Mon Sep 17 00:00:00 2001 From: radhitya Date: Thu, 18 Jun 2026 12:42:29 +0700 Subject: migration to new dns library --- internal/server/doh.go | 14 ++++---- internal/server/handler.go | 77 ++++++++++++++++++++++++++++-------------- internal/server/server.go | 12 +++---- internal/server/server_test.go | 20 +++++------ 4 files changed, 72 insertions(+), 51 deletions(-) (limited to 'internal/server') diff --git a/internal/server/doh.go b/internal/server/doh.go index b46736e..2f9dfc0 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -2,7 +2,7 @@ package server import ( "encoding/base64" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "io" "log/slog" "net/http" @@ -42,7 +42,8 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { } msg := new(dns.Msg) - if err := msg.Unpack(raw); err != nil { + msg.Data = raw + if err := msg.Unpack(); err != nil { http.Error(w, "invalid dns message", http.StatusBadRequest) return } @@ -53,22 +54,21 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { } resp, _ := s.buildResponse(msg) - packed, err := resp.Pack() - if err != nil { + if err := resp.Pack(); err != nil { http.Error(w, "pack response", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/dns-message") w.Header().Set("Cache-Control", "no-cache, max-age=0") - if _, err := w.Write(packed); err != nil { + if _, err := w.Write(resp.Data); err != nil { slog.Error("doh write failed", "err", err) return } slog.Info("doh query served", - "qname", msg.Question[0].Name, - "qtype", dns.TypeToString[msg.Question[0].Qtype], + "qname", msg.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])], "rcode", dns.RcodeToString[resp.Rcode], "client", r.RemoteAddr, ) diff --git a/internal/server/handler.go b/internal/server/handler.go index 5f873d4..d0aa705 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,36 +2,43 @@ package server import ( "context" + "io" "log/slog" - "net" + "net/netip" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/rdata" "linum/internal/blocklist" "linum/internal/cache" ) -func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { +func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - _ = w.WriteMsg(m) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + io.Copy(w, m) return } resp, blocked := s.buildResponse(req) - if err := w.WriteMsg(resp); err != nil { + resp.ID = req.ID + resp.Data = nil + if _, err := io.Copy(w, resp); err != nil { slog.Error("write response failed", "err", err, - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked, ) @@ -39,18 +46,26 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false + m := new(dns.Msg) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + return m, false } q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { return s.blockedResponse(req), true } if s.cache != nil { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} if cached, ok := s.cache.Get(key); ok { - cached.Id = req.Id + cached.ID = req.ID return cached, false } } @@ -58,22 +73,25 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() - reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) + reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", "err", err, - "qname", q.Name, - "qtype", dns.TypeToString[q.Qtype], + "qname", qname, + "qtype", dns.TypeToString[qtype], ) m := new(dns.Msg) - m.SetRcode(req, dns.RcodeServerFailure) + m.Rcode = dns.RcodeServerFailure + m.Response = true + m.ID = req.ID + m.Question = req.Question return m, false } - reply.Id = req.Id + reply.ID = req.ID if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} s.cache.Set(key, reply, 0) } @@ -82,7 +100,11 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { m := new(dns.Msg) - m.SetReply(req) + m.Response = true + m.ID = req.ID + m.Opcode = req.Opcode + m.RecursionDesired = req.RecursionDesired + m.Question = req.Question m.Authoritative = true if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { @@ -91,16 +113,19 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { } q := req.Question[0] - switch q.Qtype { + qname := q.Header().Name + qtype := dns.RRToType(q) + + switch qtype { case dns.TypeA: m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.IPv4(0, 0, 0, 0), + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{})}, }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, - AAAA: net.IPv6zero, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m diff --git a/internal/server/server.go b/internal/server/server.go index 8f991eb..1aa3256 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "linum/internal/resolver" "linum/internal/blocklist" "linum/internal/cache" @@ -41,7 +41,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { Handler: mux, UDPSize: 4096, ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, ReusePort: true, } } @@ -52,7 +51,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { Net: "tcp", Handler: mux, ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, } } @@ -97,10 +95,10 @@ func (s *Server) Run(ctx context.Context) error { defer cancel() if s.udp != nil { - s.udp.ShutdownContext(shutdownCtx) + s.udp.Shutdown(shutdownCtx) } if s.tcp != nil { - s.tcp.ShutdownContext(shutdownCtx) + s.tcp.Shutdown(shutdownCtx) } if s.doh != nil { s.doh.Shutdown(shutdownCtx) @@ -113,10 +111,10 @@ func (s *Server) Run(ctx context.Context) error { func (s *Server) Close() error { if s.udp != nil { - s.udp.Shutdown() + s.udp.Shutdown(context.Background()) } if s.tcp != nil { - s.tcp.Shutdown() + s.tcp.Shutdown(context.Background()) } if s.doh != nil { s.doh.Close() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 42938d1..002aba8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "linum/internal/resolver" ) @@ -27,7 +27,7 @@ func TestBuildResponse(t *testing.T) { tests := []struct { name string req *dns.Msg - wantRcode int + wantRcode uint16 wantAnswers int wantEdns0 bool }{ @@ -52,10 +52,10 @@ func TestBuildResponse(t *testing.T) { t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers) } if tt.wantEdns0 { - if opt := resp.IsEdns0(); opt == nil { + if resp.UDPSize == 0 { t.Error("expected EDNS0 in response, got none") - } else if opt.UDPSize() != 4096 { - t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize()) + } else if resp.UDPSize != 4096 { + t.Errorf("edns0 udp size: got %d, want 4096", resp.UDPSize) } } }) @@ -64,9 +64,7 @@ func TestBuildResponse(t *testing.T) { func TestBuildResponseWithQuery(t *testing.T) { s := testServer(t) - // Valid query → must not panic, rcode must be valid - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) + m := dns.NewMsg("example.com.", dns.TypeA) resp, _ := s.buildResponse(m) if resp == nil { t.Fatal("buildResponse returned nil") @@ -80,7 +78,6 @@ func FuzzBuildResponse(f *testing.F) { baseCtx, cancel := context.WithCancel(context.Background()) defer cancel() s := &Server{logger: slog.Default(), baseCtx: baseCtx} - // For fuzz, use a resolver that won't make real network calls s.resolver = resolver.New( resolver.WithRootAddresses([]string{"127.0.0.1:1"}), resolver.WithTimeout(10*time.Millisecond), @@ -103,14 +100,15 @@ func FuzzBuildResponse(f *testing.F) { f.Add(seed) f.Fuzz(func(t *testing.T, data []byte) { msg := new(dns.Msg) - if err := msg.Unpack(data); err != nil { + msg.Data = data + if err := msg.Unpack(); err != nil { return } resp, _ := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } - if _, err := resp.Pack(); err != nil { + if err := resp.Pack(); err != nil { t.Errorf("pack failed: %v", err) } }) -- cgit v1.2.3