diff options
| author | radhitya <alif@radhitya.org> | 2026-06-18 12:42:29 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-18 12:42:29 +0700 |
| commit | f5753c6a8cac5a57a042b0388f38abeff5d1f37d (patch) | |
| tree | 96e1241126b23051725edb68a79c8e4603d7e23a | |
| parent | e05835493f821055e517a3988c6f9256abbc5c24 (diff) | |
migration to new dns library
| -rw-r--r-- | internal/cache/cache.go | 102 | ||||
| -rw-r--r-- | internal/cache/cache_test.go | 27 | ||||
| -rw-r--r-- | internal/resolver/resolver.go | 91 | ||||
| -rw-r--r-- | internal/resolver/resolver_test.go | 150 | ||||
| -rw-r--r-- | internal/resolver/root.go | 4 | ||||
| -rw-r--r-- | internal/server/doh.go | 14 | ||||
| -rw-r--r-- | internal/server/handler.go | 77 | ||||
| -rw-r--r-- | internal/server/server.go | 12 | ||||
| -rw-r--r-- | internal/server/server_test.go | 20 |
9 files changed, 350 insertions, 147 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 8b13dd1..bb35b8e 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -7,7 +7,7 @@ import ( "sync/atomic" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" _ "modernc.org/sqlite" ) @@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) { c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) - msg := e.msg.Copy() + msg := deepCopyMsg(e.msg) remaining := e.remaining() adjustTTL(msg, e.ttl, remaining) return msg, true @@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { } e := &entry{ - msg: msg.Copy(), + msg: deepCopyMsg(msg), storedAt: time.Now(), ttl: ttl, } @@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } - data, err := e.msg.Pack() + err := e.msg.Pack() if err != nil { return } + data := e.msg.Data _, err = c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) VALUES (?, ?, ?, ?, ?, ?)`, @@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() { } msg := new(dns.Msg) - if err := msg.Unpack(data); err != nil { + msg.Data = data + if err := msg.Unpack(); err != nil { continue } @@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration { var min uint32 first := true for _, rr := range msg.Answer { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Ns { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } @@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration { func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { ratio := float64(remaining) / float64(originalTTL) for _, rr := range msg.Answer { - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Ns { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } } func applyRatio(ttl uint32, ratio float64) uint32 { return uint32(float64(ttl) * ratio) } + +func deepCopyMsg(msg *dns.Msg) *dns.Msg { + cp := new(dns.Msg) + cp.Response = msg.Response + cp.ID = msg.ID + cp.Opcode = msg.Opcode + cp.Authoritative = msg.Authoritative + cp.Truncated = msg.Truncated + cp.RecursionDesired = msg.RecursionDesired + cp.RecursionAvailable = msg.RecursionAvailable + cp.Rcode = msg.Rcode + cp.UDPSize = msg.UDPSize + cp.Question = copyRRSlice(msg.Question) + cp.Answer = copyRRSlice(msg.Answer) + cp.Ns = copyRRSlice(msg.Ns) + cp.Extra = copyRRSlice(msg.Extra) + return cp +} + +func copyRRSlice(src []dns.RR) []dns.RR { + if src == nil { + return nil + } + dst := make([]dns.RR, len(src)) + for i, rr := range src { + dst[i] = copyRR(rr) + } + return dst +} + +func copyRR(rr dns.RR) dns.RR { + if rr == nil { + return nil + } + switch v := rr.(type) { + case *dns.A: + cp := *v + return &cp + case *dns.AAAA: + cp := *v + return &cp + case *dns.NS: + cp := *v + return &cp + case *dns.CNAME: + cp := *v + return &cp + case *dns.SOA: + cp := *v + return &cp + case *dns.MX: + cp := *v + return &cp + case *dns.TXT: + cp := *v + return &cp + case *dns.SRV: + cp := *v + return &cp + case *dns.PTR: + cp := *v + return &cp + default: + return rr + } +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 6556dcb..33e1d30 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -2,11 +2,12 @@ package cache import ( "fmt" - "net" + "net/netip" "testing" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/rdata" ) func TestSetGet(t *testing.T) { @@ -18,8 +19,8 @@ func TestSetGet(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300}, - A: net.IPv4(1,2,3,4), + Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})}, }) key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} @@ -33,8 +34,8 @@ func TestSetGet(t *testing.T) { t.Fatalf("expected 1 answer, got %d", len(got.Answer)) } a, _ := got.Answer[0].(*dns.A) - if !a.A.Equal(net.IPv4(1,2,3,4)) { - t.Errorf("IP = %s, want 1.2.3.4", a.A) + if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { + t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } } @@ -67,7 +68,7 @@ func TestEviction(t *testing.T) { func TestComputeTTL(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120}, + Hdr: dns.Header{Name: "x.", Class: dns.ClassINET, TTL: 120}, }) if d := computeTTL(msg); d != 120*time.Second { t.Errorf("TTL = %v, want 120s", d) @@ -78,8 +79,8 @@ func TestNegativeTTL(t *testing.T) { msg := new(dns.Msg) msg.Rcode = dns.RcodeNameError msg.Ns = append(msg.Ns, &dns.SOA{ - Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900}, - Minttl: 300, + Hdr: dns.Header{Name: "com.", Class: dns.ClassINET, TTL: 900}, + SOA: rdata.SOA{Minttl: 300}, }) if d := computeTTL(msg); d != 300*time.Second { t.Errorf("negative TTL = %v, want 300s", d) @@ -116,8 +117,8 @@ func TestSQLitePersistence(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300}, - A: net.IPv4(1, 2, 3, 4), + Hdr: dns.Header{Name: "x.com.", Class: dns.ClassINET, TTL: 300}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})}, }) key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} c.Set(key, msg, 300*time.Second) @@ -134,7 +135,7 @@ func TestSQLitePersistence(t *testing.T) { t.Fatal("expected cache hit from SQLite load") } a, _ := got.Answer[0].(*dns.A) - if !a.A.Equal(net.IPv4(1, 2, 3, 4)) { - t.Errorf("IP = %s, want 1.2.3.4", a.A) + if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { + t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 9ee9cd6..5aa7bc1 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -7,7 +7,7 @@ import ( "net" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" ) var ( @@ -26,16 +26,15 @@ type Resolver struct { type Option func(*Resolver) func New(opts ...Option) *Resolver { + transport := dns.NewTransport() + transport.ReadTimeout = 2 * time.Second + r := &Resolver{ roots: loadRootServers(), maxDelegations: 30, timeout: 2 * time.Second, retries: 2, - client: &dns.Client{ - Net: "udp", - UDPSize: 4096, - Timeout: 2 * time.Second, - }, + client: &dns.Client{Transport: transport}, } for _, opt := range opts { opt(r) @@ -52,6 +51,7 @@ func WithRootAddresses(addrs []string) Option { func WithTimeout(d time.Duration) Option { return func(r *Resolver) { r.timeout = d + r.client.Transport.ReadTimeout = d } } @@ -81,6 +81,9 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: + if needsCNAMEResolution(reply, qtype) { + return r.resolveCNAME(ctx, reply, qtype) + } return reply, nil case reply.Rcode == dns.RcodeNameError: return reply, nil @@ -104,6 +107,62 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn return nil, ErrMaxDelegations } +func needsCNAMEResolution(reply *dns.Msg, qtype uint16) bool { + hasCNAME := false + hasTarget := false + for _, rr := range reply.Answer { + if _, ok := rr.(*dns.CNAME); ok { + hasCNAME = true + } + if dns.RRToType(rr) == qtype && dns.RRToType(rr) != dns.TypeCNAME { + hasTarget = true + } + } + return hasCNAME && !hasTarget +} +func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint16) (*dns.Msg, error) { + var target string + for _, rr := range reply.Answer { + if cn, ok := rr.(*dns.CNAME); ok { + target = cn.Target + } + } + if target == "" { + return reply, nil + } + + const maxChain = 10 + seen := map[string]bool{reply.Question[0].Header().Name: true} + for i := 0; i < maxChain; i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if seen[target] { + break + } + seen[target] = true + + chainReply, err := r.resolve(ctx, target, qtype) + if err != nil { + return reply, nil + } + reply.Answer = append(reply.Answer, chainReply.Answer...) + + nextTarget := "" + for _, rr := range chainReply.Answer { + if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target { + nextTarget = cn.Target + } + } + if nextTarget == "" { + break + } + target = nextTarget + } + return reply, nil +} func isReferral(msg *dns.Msg) bool { return !msg.Authoritative && len(msg.Ns) > 0 } @@ -122,12 +181,12 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err for _, rr := range msg.Extra { switch v := rr.(type) { case *dns.A: - if _, exists := glue[v.Hdr.Name]; !exists { - glue[v.Hdr.Name] = v.A.String() + if _, exists := glue[v.Header().Name]; !exists { + glue[v.Header().Name] = v.A.Addr.String() } case *dns.AAAA: - if _, exists := glue[v.Hdr.Name]; !exists { - glue[v.Hdr.Name] = v.AAAA.String() + if _, exists := glue[v.Header().Name]; !exists { + glue[v.Header().Name] = v.AAAA.Addr.String() } } } @@ -156,7 +215,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err } for _, rr := range reply.Answer { if a, ok := rr.(*dns.A); ok { - addrs = append(addrs, a.A.String()) + addrs = append(addrs, a.A.Addr.String()) break } } @@ -171,9 +230,8 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, qname string, qtype uint16) (*dns.Msg, error) { - msg := new(dns.Msg) - msg.SetQuestion(qname, qtype) - msg.SetEdns0(4096, false) + msg := dns.NewMsg(qname, qtype) + msg.UDPSize = 4096 msg.RecursionDesired = false type result struct { @@ -189,11 +247,10 @@ func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, } go func(addr string) { for attempt := 0; attempt < r.retries; attempt++ { - reply, _, err := r.client.ExchangeContext(ctx, msg, addr) + reply, _, err := r.client.Exchange(ctx, msg, "udp", addr) if err == nil { if reply.Truncated { - tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout} - reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr) + reply, _, err = r.client.Exchange(ctx, msg, "tcp", addr) if err == nil { ch <- result{reply: reply, err: nil} return diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 0bd0402..54f727e 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -2,54 +2,53 @@ package resolver import ( "context" - "net" + "io" + "net/netip" "testing" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/dnsutil" + "codeberg.org/miekg/dns/rdata" ) func startTestServer(t *testing.T, addr string, handler dns.Handler) *dns.Server { t.Helper() + ready := make(chan struct{}) srv := &dns.Server{ Addr: addr, Net: "udp", Handler: handler, UDPSize: 4096, + NotifyStartedFunc: func(ctx context.Context) { + close(ready) + }, } go func() { - if err := srv.ListenAndServe(); err != nil { - } + _ = srv.ListenAndServe() }() + <-ready return srv } func TestLookupDirectAnswer(t *testing.T) { - mux := dns.NewServeMux() - mux.HandleFunc(".", func(w dns.ResponseWriter, req *dns.Msg) { - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = true - if req.Question[0].Name == "example.com." && - req.Question[0].Qtype == dns.TypeA { + srv := startTestServer(t, "127.0.0.1:15353", + dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + resp := new(dns.Msg) + dnsutil.SetReply(resp, req) + resp.Authoritative = true resp.Answer = append(resp.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: "example.com.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 60, + Hdr: dns.Header{ + Name: "example.com.", + Class: dns.ClassINET, + TTL: 60, }, - A: net.IPv4(127, 0, 0, 1), + A: rdata.A{Addr: netip.AddrFrom4([4]byte{127, 0, 0, 1})}, }) - } else { - resp.Rcode = dns.RcodeNameError - } - w.WriteMsg(resp) - }) - - srv := startTestServer(t, "127.0.0.1:15353", mux) - defer srv.Shutdown() - time.Sleep(50 * time.Millisecond) + io.Copy(w, resp) + })) + defer srv.Shutdown(context.Background()) + time.Sleep(100 * time.Millisecond) r := New( WithRootAddresses([]string{"127.0.0.1:15353"}), @@ -73,23 +72,21 @@ func TestLookupDirectAnswer(t *testing.T) { if !ok { t.Fatal("expected A record") } - if !a.A.Equal(net.IPv4(127, 0, 0, 1)) { - t.Fatalf("expected 127.0.0.1, got %s", a.A) + if a.A.Addr != netip.AddrFrom4([4]byte{127, 0, 0, 1}) { + t.Fatalf("expected 127.0.0.1, got %s", a.A.Addr) } } func TestLookupNXDOMAIN(t *testing.T) { - mux := dns.NewServeMux() - mux.HandleFunc(".", func(w dns.ResponseWriter, req *dns.Msg) { - resp := new(dns.Msg) - resp.SetReply(req) - resp.Rcode = dns.RcodeNameError - w.WriteMsg(resp) - }) - - srv := startTestServer(t, "127.0.0.1:15354", mux) - defer srv.Shutdown() - time.Sleep(50 * time.Millisecond) + srv := startTestServer(t, "127.0.0.1:15354", + dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + resp := new(dns.Msg) + dnsutil.SetReply(resp, req) + resp.Rcode = dns.RcodeNameError + io.Copy(w, resp) + })) + defer srv.Shutdown(context.Background()) + time.Sleep(100 * time.Millisecond) r := New( WithRootAddresses([]string{"127.0.0.1:15354"}), @@ -111,16 +108,16 @@ func TestNextServersWithGlue(t *testing.T) { msg := new(dns.Msg) msg.Authoritative = false msg.Ns = append(msg.Ns, &dns.NS{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeNS, Ttl: 300}, - Ns: "ns1.example.com.", + Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300}, + NS: rdata.NS{Ns: "ns1.example.com."}, }) msg.Extra = append(msg.Extra, &dns.A{ - Hdr: dns.RR_Header{Name: "ns1.example.com.", Rrtype: dns.TypeA, Ttl: 300}, - A: net.ParseIP("192.0.2.1").To4(), + Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300}, + A: rdata.A{Addr: netip.MustParseAddr("192.0.2.1")}, }) msg.Extra = append(msg.Extra, &dns.AAAA{ - Hdr: dns.RR_Header{Name: "ns1.example.com.", Rrtype: dns.TypeAAAA, Ttl: 300}, - AAAA: net.ParseIP("2001:db8::1"), + Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300}, + AAAA: rdata.AAAA{Addr: netip.MustParseAddr("2001:db8::1")}, }) r := &Resolver{} @@ -137,8 +134,8 @@ func TestNextServersNoGlue(t *testing.T) { msg := new(dns.Msg) msg.Authoritative = false msg.Ns = append(msg.Ns, &dns.NS{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeNS, Ttl: 300}, - Ns: "ns1.example.com.", + Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300}, + NS: rdata.NS{Ns: "ns1.example.com."}, }) r := &Resolver{maxDelegations: 30, timeout: time.Second, retries: 1} @@ -147,3 +144,62 @@ func TestNextServersNoGlue(t *testing.T) { t.Fatal("expected error when no glue and no roots") } } + +func TestLookupCNAME(t *testing.T) { + callCount := 0 + srv := startTestServer(t, "127.0.0.1:15355", + dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + callCount++ + resp := new(dns.Msg) + dnsutil.SetReply(resp, req) + resp.Authoritative = true + resp.Answer = append(resp.Answer, + &dns.CNAME{ + Hdr: dns.Header{Name: "alias.example.com.", Class: dns.ClassINET, TTL: 60}, + CNAME: rdata.CNAME{Target: "real.example.com."}, + }, + ) + if callCount > 1 { + resp.Answer = append(resp.Answer, + &dns.A{ + Hdr: dns.Header{Name: "real.example.com.", Class: dns.ClassINET, TTL: 60}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{127, 0, 0, 1})}, + }, + ) + } + io.Copy(w, resp) + })) + defer srv.Shutdown(context.Background()) + time.Sleep(100 * time.Millisecond) + + r := New( + WithRootAddresses([]string{"127.0.0.1:15355"}), + WithTimeout(time.Second), + ) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + resp, err := r.Lookup(ctx, "alias.example.com.", dns.TypeA) + if err != nil { + t.Fatalf("CNAME lookup failed: %v", err) + } + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("expected NOERROR, got %d", resp.Rcode) + } + hasCNAME := false + hasA := false + for _, rr := range resp.Answer { + if _, ok := rr.(*dns.CNAME); ok { + hasCNAME = true + } + if _, ok := rr.(*dns.A); ok { + hasA = true + } + } + if !hasCNAME { + t.Fatal("expected CNAME record in answer") + } + if !hasA { + t.Fatal("expected A record (CNAME target) in answer") + } +} diff --git a/internal/resolver/root.go b/internal/resolver/root.go index 9dac31c..0557dac 100644 --- a/internal/resolver/root.go +++ b/internal/resolver/root.go @@ -2,7 +2,7 @@ package resolver import ( _ "embed" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "strings" ) @@ -23,7 +23,7 @@ func loadRootServers() []string { if !ok { continue } - ip := a.A.String() + ip := a.A.Addr.String() if !seen[ip] { seen[ip] = true addrs = append(addrs, ip) diff --git a/internal/server/doh.go b/internal/server/doh.go index b46736e..2f9dfc0 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -2,7 +2,7 @@ package server import ( "encoding/base64" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "io" "log/slog" "net/http" @@ -42,7 +42,8 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { } msg := new(dns.Msg) - if err := msg.Unpack(raw); err != nil { + msg.Data = raw + if err := msg.Unpack(); err != nil { http.Error(w, "invalid dns message", http.StatusBadRequest) return } @@ -53,22 +54,21 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { } resp, _ := s.buildResponse(msg) - packed, err := resp.Pack() - if err != nil { + if err := resp.Pack(); err != nil { http.Error(w, "pack response", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/dns-message") w.Header().Set("Cache-Control", "no-cache, max-age=0") - if _, err := w.Write(packed); err != nil { + if _, err := w.Write(resp.Data); err != nil { slog.Error("doh write failed", "err", err) return } slog.Info("doh query served", - "qname", msg.Question[0].Name, - "qtype", dns.TypeToString[msg.Question[0].Qtype], + "qname", msg.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])], "rcode", dns.RcodeToString[resp.Rcode], "client", r.RemoteAddr, ) diff --git a/internal/server/handler.go b/internal/server/handler.go index 5f873d4..d0aa705 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,36 +2,43 @@ package server import ( "context" + "io" "log/slog" - "net" + "net/netip" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/rdata" "linum/internal/blocklist" "linum/internal/cache" ) -func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { +func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - _ = w.WriteMsg(m) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + io.Copy(w, m) return } resp, blocked := s.buildResponse(req) - if err := w.WriteMsg(resp); err != nil { + resp.ID = req.ID + resp.Data = nil + if _, err := io.Copy(w, resp); err != nil { slog.Error("write response failed", "err", err, - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked, ) @@ -39,18 +46,26 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false + m := new(dns.Msg) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + return m, false } q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { return s.blockedResponse(req), true } if s.cache != nil { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} if cached, ok := s.cache.Get(key); ok { - cached.Id = req.Id + cached.ID = req.ID return cached, false } } @@ -58,22 +73,25 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() - reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) + reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", "err", err, - "qname", q.Name, - "qtype", dns.TypeToString[q.Qtype], + "qname", qname, + "qtype", dns.TypeToString[qtype], ) m := new(dns.Msg) - m.SetRcode(req, dns.RcodeServerFailure) + m.Rcode = dns.RcodeServerFailure + m.Response = true + m.ID = req.ID + m.Question = req.Question return m, false } - reply.Id = req.Id + reply.ID = req.ID if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} s.cache.Set(key, reply, 0) } @@ -82,7 +100,11 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { m := new(dns.Msg) - m.SetReply(req) + m.Response = true + m.ID = req.ID + m.Opcode = req.Opcode + m.RecursionDesired = req.RecursionDesired + m.Question = req.Question m.Authoritative = true if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { @@ -91,16 +113,19 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { } q := req.Question[0] - switch q.Qtype { + qname := q.Header().Name + qtype := dns.RRToType(q) + + switch qtype { case dns.TypeA: m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.IPv4(0, 0, 0, 0), + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{})}, }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, - AAAA: net.IPv6zero, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m diff --git a/internal/server/server.go b/internal/server/server.go index 8f991eb..1aa3256 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "linum/internal/resolver" "linum/internal/blocklist" "linum/internal/cache" @@ -41,7 +41,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { Handler: mux, UDPSize: 4096, ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, ReusePort: true, } } @@ -52,7 +51,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { Net: "tcp", Handler: mux, ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, } } @@ -97,10 +95,10 @@ func (s *Server) Run(ctx context.Context) error { defer cancel() if s.udp != nil { - s.udp.ShutdownContext(shutdownCtx) + s.udp.Shutdown(shutdownCtx) } if s.tcp != nil { - s.tcp.ShutdownContext(shutdownCtx) + s.tcp.Shutdown(shutdownCtx) } if s.doh != nil { s.doh.Shutdown(shutdownCtx) @@ -113,10 +111,10 @@ func (s *Server) Run(ctx context.Context) error { func (s *Server) Close() error { if s.udp != nil { - s.udp.Shutdown() + s.udp.Shutdown(context.Background()) } if s.tcp != nil { - s.tcp.Shutdown() + s.tcp.Shutdown(context.Background()) } if s.doh != nil { s.doh.Close() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 42938d1..002aba8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" "linum/internal/resolver" ) @@ -27,7 +27,7 @@ func TestBuildResponse(t *testing.T) { tests := []struct { name string req *dns.Msg - wantRcode int + wantRcode uint16 wantAnswers int wantEdns0 bool }{ @@ -52,10 +52,10 @@ func TestBuildResponse(t *testing.T) { t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers) } if tt.wantEdns0 { - if opt := resp.IsEdns0(); opt == nil { + if resp.UDPSize == 0 { t.Error("expected EDNS0 in response, got none") - } else if opt.UDPSize() != 4096 { - t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize()) + } else if resp.UDPSize != 4096 { + t.Errorf("edns0 udp size: got %d, want 4096", resp.UDPSize) } } }) @@ -64,9 +64,7 @@ func TestBuildResponse(t *testing.T) { func TestBuildResponseWithQuery(t *testing.T) { s := testServer(t) - // Valid query → must not panic, rcode must be valid - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) + m := dns.NewMsg("example.com.", dns.TypeA) resp, _ := s.buildResponse(m) if resp == nil { t.Fatal("buildResponse returned nil") @@ -80,7 +78,6 @@ func FuzzBuildResponse(f *testing.F) { baseCtx, cancel := context.WithCancel(context.Background()) defer cancel() s := &Server{logger: slog.Default(), baseCtx: baseCtx} - // For fuzz, use a resolver that won't make real network calls s.resolver = resolver.New( resolver.WithRootAddresses([]string{"127.0.0.1:1"}), resolver.WithTimeout(10*time.Millisecond), @@ -103,14 +100,15 @@ func FuzzBuildResponse(f *testing.F) { f.Add(seed) f.Fuzz(func(t *testing.T, data []byte) { msg := new(dns.Msg) - if err := msg.Unpack(data); err != nil { + msg.Data = data + if err := msg.Unpack(); err != nil { return } resp, _ := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } - if _, err := resp.Pack(); err != nil { + if err := resp.Pack(); err != nil { t.Errorf("pack failed: %v", err) } }) |
