summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
committerradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
commitf5753c6a8cac5a57a042b0388f38abeff5d1f37d (patch)
tree96e1241126b23051725edb68a79c8e4603d7e23a /internal
parente05835493f821055e517a3988c6f9256abbc5c24 (diff)
migration to new dns library
Diffstat (limited to 'internal')
-rw-r--r--internal/cache/cache.go102
-rw-r--r--internal/cache/cache_test.go27
-rw-r--r--internal/resolver/resolver.go91
-rw-r--r--internal/resolver/resolver_test.go150
-rw-r--r--internal/resolver/root.go4
-rw-r--r--internal/server/doh.go14
-rw-r--r--internal/server/handler.go77
-rw-r--r--internal/server/server.go12
-rw-r--r--internal/server/server_test.go20
9 files changed, 350 insertions, 147 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 8b13dd1..bb35b8e 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -7,7 +7,7 @@ import (
"sync/atomic"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
_ "modernc.org/sqlite"
)
@@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) {
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
- msg := e.msg.Copy()
+ msg := deepCopyMsg(e.msg)
remaining := e.remaining()
adjustTTL(msg, e.ttl, remaining)
return msg, true
@@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
}
e := &entry{
- msg: msg.Copy(),
+ msg: deepCopyMsg(msg),
storedAt: time.Now(),
ttl: ttl,
}
@@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) {
if c.db == nil {
return
}
- data, err := e.msg.Pack()
+ err := e.msg.Pack()
if err != nil {
return
}
+ data := e.msg.Data
_, err = c.db.Exec(
`INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
VALUES (?, ?, ?, ?, ?, ?)`,
@@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() {
}
msg := new(dns.Msg)
- if err := msg.Unpack(data); err != nil {
+ msg.Data = data
+ if err := msg.Unpack(); err != nil {
continue
}
@@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration {
var min uint32
first := true
for _, rr := range msg.Answer {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Ns {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
@@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration {
func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
ratio := float64(remaining) / float64(originalTTL)
for _, rr := range msg.Answer {
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Ns {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
}
func applyRatio(ttl uint32, ratio float64) uint32 {
return uint32(float64(ttl) * ratio)
}
+
+func deepCopyMsg(msg *dns.Msg) *dns.Msg {
+ cp := new(dns.Msg)
+ cp.Response = msg.Response
+ cp.ID = msg.ID
+ cp.Opcode = msg.Opcode
+ cp.Authoritative = msg.Authoritative
+ cp.Truncated = msg.Truncated
+ cp.RecursionDesired = msg.RecursionDesired
+ cp.RecursionAvailable = msg.RecursionAvailable
+ cp.Rcode = msg.Rcode
+ cp.UDPSize = msg.UDPSize
+ cp.Question = copyRRSlice(msg.Question)
+ cp.Answer = copyRRSlice(msg.Answer)
+ cp.Ns = copyRRSlice(msg.Ns)
+ cp.Extra = copyRRSlice(msg.Extra)
+ return cp
+}
+
+func copyRRSlice(src []dns.RR) []dns.RR {
+ if src == nil {
+ return nil
+ }
+ dst := make([]dns.RR, len(src))
+ for i, rr := range src {
+ dst[i] = copyRR(rr)
+ }
+ return dst
+}
+
+func copyRR(rr dns.RR) dns.RR {
+ if rr == nil {
+ return nil
+ }
+ switch v := rr.(type) {
+ case *dns.A:
+ cp := *v
+ return &cp
+ case *dns.AAAA:
+ cp := *v
+ return &cp
+ case *dns.NS:
+ cp := *v
+ return &cp
+ case *dns.CNAME:
+ cp := *v
+ return &cp
+ case *dns.SOA:
+ cp := *v
+ return &cp
+ case *dns.MX:
+ cp := *v
+ return &cp
+ case *dns.TXT:
+ cp := *v
+ return &cp
+ case *dns.SRV:
+ cp := *v
+ return &cp
+ case *dns.PTR:
+ cp := *v
+ return &cp
+ default:
+ return rr
+ }
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
index 6556dcb..33e1d30 100644
--- a/internal/cache/cache_test.go
+++ b/internal/cache/cache_test.go
@@ -2,11 +2,12 @@ package cache
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
+ "codeberg.org/miekg/dns/rdata"
)
func TestSetGet(t *testing.T) {
@@ -18,8 +19,8 @@ func TestSetGet(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300},
- A: net.IPv4(1,2,3,4),
+ Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})},
})
key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
@@ -33,8 +34,8 @@ func TestSetGet(t *testing.T) {
t.Fatalf("expected 1 answer, got %d", len(got.Answer))
}
a, _ := got.Answer[0].(*dns.A)
- if !a.A.Equal(net.IPv4(1,2,3,4)) {
- t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
}
@@ -67,7 +68,7 @@ func TestEviction(t *testing.T) {
func TestComputeTTL(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120},
+ Hdr: dns.Header{Name: "x.", Class: dns.ClassINET, TTL: 120},
})
if d := computeTTL(msg); d != 120*time.Second {
t.Errorf("TTL = %v, want 120s", d)
@@ -78,8 +79,8 @@ func TestNegativeTTL(t *testing.T) {
msg := new(dns.Msg)
msg.Rcode = dns.RcodeNameError
msg.Ns = append(msg.Ns, &dns.SOA{
- Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900},
- Minttl: 300,
+ Hdr: dns.Header{Name: "com.", Class: dns.ClassINET, TTL: 900},
+ SOA: rdata.SOA{Minttl: 300},
})
if d := computeTTL(msg); d != 300*time.Second {
t.Errorf("negative TTL = %v, want 300s", d)
@@ -116,8 +117,8 @@ func TestSQLitePersistence(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300},
- A: net.IPv4(1, 2, 3, 4),
+ Hdr: dns.Header{Name: "x.com.", Class: dns.ClassINET, TTL: 300},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})},
})
key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
c.Set(key, msg, 300*time.Second)
@@ -134,7 +135,7 @@ func TestSQLitePersistence(t *testing.T) {
t.Fatal("expected cache hit from SQLite load")
}
a, _ := got.Answer[0].(*dns.A)
- if !a.A.Equal(net.IPv4(1, 2, 3, 4)) {
- t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
}
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 9ee9cd6..5aa7bc1 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -7,7 +7,7 @@ import (
"net"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
)
var (
@@ -26,16 +26,15 @@ type Resolver struct {
type Option func(*Resolver)
func New(opts ...Option) *Resolver {
+ transport := dns.NewTransport()
+ transport.ReadTimeout = 2 * time.Second
+
r := &Resolver{
roots: loadRootServers(),
maxDelegations: 30,
timeout: 2 * time.Second,
retries: 2,
- client: &dns.Client{
- Net: "udp",
- UDPSize: 4096,
- Timeout: 2 * time.Second,
- },
+ client: &dns.Client{Transport: transport},
}
for _, opt := range opts {
opt(r)
@@ -52,6 +51,7 @@ func WithRootAddresses(addrs []string) Option {
func WithTimeout(d time.Duration) Option {
return func(r *Resolver) {
r.timeout = d
+ r.client.Transport.ReadTimeout = d
}
}
@@ -81,6 +81,9 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
}
switch {
case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
+ if needsCNAMEResolution(reply, qtype) {
+ return r.resolveCNAME(ctx, reply, qtype)
+ }
return reply, nil
case reply.Rcode == dns.RcodeNameError:
return reply, nil
@@ -104,6 +107,62 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
return nil, ErrMaxDelegations
}
+func needsCNAMEResolution(reply *dns.Msg, qtype uint16) bool {
+ hasCNAME := false
+ hasTarget := false
+ for _, rr := range reply.Answer {
+ if _, ok := rr.(*dns.CNAME); ok {
+ hasCNAME = true
+ }
+ if dns.RRToType(rr) == qtype && dns.RRToType(rr) != dns.TypeCNAME {
+ hasTarget = true
+ }
+ }
+ return hasCNAME && !hasTarget
+}
+func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint16) (*dns.Msg, error) {
+ var target string
+ for _, rr := range reply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok {
+ target = cn.Target
+ }
+ }
+ if target == "" {
+ return reply, nil
+ }
+
+ const maxChain = 10
+ seen := map[string]bool{reply.Question[0].Header().Name: true}
+ for i := 0; i < maxChain; i++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+ if seen[target] {
+ break
+ }
+ seen[target] = true
+
+ chainReply, err := r.resolve(ctx, target, qtype)
+ if err != nil {
+ return reply, nil
+ }
+ reply.Answer = append(reply.Answer, chainReply.Answer...)
+
+ nextTarget := ""
+ for _, rr := range chainReply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
+ nextTarget = cn.Target
+ }
+ }
+ if nextTarget == "" {
+ break
+ }
+ target = nextTarget
+ }
+ return reply, nil
+}
func isReferral(msg *dns.Msg) bool {
return !msg.Authoritative && len(msg.Ns) > 0
}
@@ -122,12 +181,12 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
for _, rr := range msg.Extra {
switch v := rr.(type) {
case *dns.A:
- if _, exists := glue[v.Hdr.Name]; !exists {
- glue[v.Hdr.Name] = v.A.String()
+ if _, exists := glue[v.Header().Name]; !exists {
+ glue[v.Header().Name] = v.A.Addr.String()
}
case *dns.AAAA:
- if _, exists := glue[v.Hdr.Name]; !exists {
- glue[v.Hdr.Name] = v.AAAA.String()
+ if _, exists := glue[v.Header().Name]; !exists {
+ glue[v.Header().Name] = v.AAAA.Addr.String()
}
}
}
@@ -156,7 +215,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.A); ok {
- addrs = append(addrs, a.A.String())
+ addrs = append(addrs, a.A.Addr.String())
break
}
}
@@ -171,9 +230,8 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
qname string, qtype uint16) (*dns.Msg, error) {
- msg := new(dns.Msg)
- msg.SetQuestion(qname, qtype)
- msg.SetEdns0(4096, false)
+ msg := dns.NewMsg(qname, qtype)
+ msg.UDPSize = 4096
msg.RecursionDesired = false
type result struct {
@@ -189,11 +247,10 @@ func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
}
go func(addr string) {
for attempt := 0; attempt < r.retries; attempt++ {
- reply, _, err := r.client.ExchangeContext(ctx, msg, addr)
+ reply, _, err := r.client.Exchange(ctx, msg, "udp", addr)
if err == nil {
if reply.Truncated {
- tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout}
- reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr)
+ reply, _, err = r.client.Exchange(ctx, msg, "tcp", addr)
if err == nil {
ch <- result{reply: reply, err: nil}
return
diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go
index 0bd0402..54f727e 100644
--- a/internal/resolver/resolver_test.go
+++ b/internal/resolver/resolver_test.go
@@ -2,54 +2,53 @@ package resolver
import (
"context"
- "net"
+ "io"
+ "net/netip"
"testing"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
+ "codeberg.org/miekg/dns/dnsutil"
+ "codeberg.org/miekg/dns/rdata"
)
func startTestServer(t *testing.T, addr string, handler dns.Handler) *dns.Server {
t.Helper()
+ ready := make(chan struct{})
srv := &dns.Server{
Addr: addr,
Net: "udp",
Handler: handler,
UDPSize: 4096,
+ NotifyStartedFunc: func(ctx context.Context) {
+ close(ready)
+ },
}
go func() {
- if err := srv.ListenAndServe(); err != nil {
- }
+ _ = srv.ListenAndServe()
}()
+ <-ready
return srv
}
func TestLookupDirectAnswer(t *testing.T) {
- mux := dns.NewServeMux()
- mux.HandleFunc(".", func(w dns.ResponseWriter, req *dns.Msg) {
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Authoritative = true
- if req.Question[0].Name == "example.com." &&
- req.Question[0].Qtype == dns.TypeA {
+ srv := startTestServer(t, "127.0.0.1:15353",
+ dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ resp := new(dns.Msg)
+ dnsutil.SetReply(resp, req)
+ resp.Authoritative = true
resp.Answer = append(resp.Answer, &dns.A{
- Hdr: dns.RR_Header{
- Name: "example.com.",
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 60,
+ Hdr: dns.Header{
+ Name: "example.com.",
+ Class: dns.ClassINET,
+ TTL: 60,
},
- A: net.IPv4(127, 0, 0, 1),
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{127, 0, 0, 1})},
})
- } else {
- resp.Rcode = dns.RcodeNameError
- }
- w.WriteMsg(resp)
- })
-
- srv := startTestServer(t, "127.0.0.1:15353", mux)
- defer srv.Shutdown()
- time.Sleep(50 * time.Millisecond)
+ io.Copy(w, resp)
+ }))
+ defer srv.Shutdown(context.Background())
+ time.Sleep(100 * time.Millisecond)
r := New(
WithRootAddresses([]string{"127.0.0.1:15353"}),
@@ -73,23 +72,21 @@ func TestLookupDirectAnswer(t *testing.T) {
if !ok {
t.Fatal("expected A record")
}
- if !a.A.Equal(net.IPv4(127, 0, 0, 1)) {
- t.Fatalf("expected 127.0.0.1, got %s", a.A)
+ if a.A.Addr != netip.AddrFrom4([4]byte{127, 0, 0, 1}) {
+ t.Fatalf("expected 127.0.0.1, got %s", a.A.Addr)
}
}
func TestLookupNXDOMAIN(t *testing.T) {
- mux := dns.NewServeMux()
- mux.HandleFunc(".", func(w dns.ResponseWriter, req *dns.Msg) {
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Rcode = dns.RcodeNameError
- w.WriteMsg(resp)
- })
-
- srv := startTestServer(t, "127.0.0.1:15354", mux)
- defer srv.Shutdown()
- time.Sleep(50 * time.Millisecond)
+ srv := startTestServer(t, "127.0.0.1:15354",
+ dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ resp := new(dns.Msg)
+ dnsutil.SetReply(resp, req)
+ resp.Rcode = dns.RcodeNameError
+ io.Copy(w, resp)
+ }))
+ defer srv.Shutdown(context.Background())
+ time.Sleep(100 * time.Millisecond)
r := New(
WithRootAddresses([]string{"127.0.0.1:15354"}),
@@ -111,16 +108,16 @@ func TestNextServersWithGlue(t *testing.T) {
msg := new(dns.Msg)
msg.Authoritative = false
msg.Ns = append(msg.Ns, &dns.NS{
- Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeNS, Ttl: 300},
- Ns: "ns1.example.com.",
+ Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300},
+ NS: rdata.NS{Ns: "ns1.example.com."},
})
msg.Extra = append(msg.Extra, &dns.A{
- Hdr: dns.RR_Header{Name: "ns1.example.com.", Rrtype: dns.TypeA, Ttl: 300},
- A: net.ParseIP("192.0.2.1").To4(),
+ Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300},
+ A: rdata.A{Addr: netip.MustParseAddr("192.0.2.1")},
})
msg.Extra = append(msg.Extra, &dns.AAAA{
- Hdr: dns.RR_Header{Name: "ns1.example.com.", Rrtype: dns.TypeAAAA, Ttl: 300},
- AAAA: net.ParseIP("2001:db8::1"),
+ Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300},
+ AAAA: rdata.AAAA{Addr: netip.MustParseAddr("2001:db8::1")},
})
r := &Resolver{}
@@ -137,8 +134,8 @@ func TestNextServersNoGlue(t *testing.T) {
msg := new(dns.Msg)
msg.Authoritative = false
msg.Ns = append(msg.Ns, &dns.NS{
- Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeNS, Ttl: 300},
- Ns: "ns1.example.com.",
+ Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300},
+ NS: rdata.NS{Ns: "ns1.example.com."},
})
r := &Resolver{maxDelegations: 30, timeout: time.Second, retries: 1}
@@ -147,3 +144,62 @@ func TestNextServersNoGlue(t *testing.T) {
t.Fatal("expected error when no glue and no roots")
}
}
+
+func TestLookupCNAME(t *testing.T) {
+ callCount := 0
+ srv := startTestServer(t, "127.0.0.1:15355",
+ dns.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ callCount++
+ resp := new(dns.Msg)
+ dnsutil.SetReply(resp, req)
+ resp.Authoritative = true
+ resp.Answer = append(resp.Answer,
+ &dns.CNAME{
+ Hdr: dns.Header{Name: "alias.example.com.", Class: dns.ClassINET, TTL: 60},
+ CNAME: rdata.CNAME{Target: "real.example.com."},
+ },
+ )
+ if callCount > 1 {
+ resp.Answer = append(resp.Answer,
+ &dns.A{
+ Hdr: dns.Header{Name: "real.example.com.", Class: dns.ClassINET, TTL: 60},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{127, 0, 0, 1})},
+ },
+ )
+ }
+ io.Copy(w, resp)
+ }))
+ defer srv.Shutdown(context.Background())
+ time.Sleep(100 * time.Millisecond)
+
+ r := New(
+ WithRootAddresses([]string{"127.0.0.1:15355"}),
+ WithTimeout(time.Second),
+ )
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ resp, err := r.Lookup(ctx, "alias.example.com.", dns.TypeA)
+ if err != nil {
+ t.Fatalf("CNAME lookup failed: %v", err)
+ }
+ if resp.Rcode != dns.RcodeSuccess {
+ t.Fatalf("expected NOERROR, got %d", resp.Rcode)
+ }
+ hasCNAME := false
+ hasA := false
+ for _, rr := range resp.Answer {
+ if _, ok := rr.(*dns.CNAME); ok {
+ hasCNAME = true
+ }
+ if _, ok := rr.(*dns.A); ok {
+ hasA = true
+ }
+ }
+ if !hasCNAME {
+ t.Fatal("expected CNAME record in answer")
+ }
+ if !hasA {
+ t.Fatal("expected A record (CNAME target) in answer")
+ }
+}
diff --git a/internal/resolver/root.go b/internal/resolver/root.go
index 9dac31c..0557dac 100644
--- a/internal/resolver/root.go
+++ b/internal/resolver/root.go
@@ -2,7 +2,7 @@ package resolver
import (
_ "embed"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"strings"
)
@@ -23,7 +23,7 @@ func loadRootServers() []string {
if !ok {
continue
}
- ip := a.A.String()
+ ip := a.A.Addr.String()
if !seen[ip] {
seen[ip] = true
addrs = append(addrs, ip)
diff --git a/internal/server/doh.go b/internal/server/doh.go
index b46736e..2f9dfc0 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -2,7 +2,7 @@ package server
import (
"encoding/base64"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"io"
"log/slog"
"net/http"
@@ -42,7 +42,8 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
}
msg := new(dns.Msg)
- if err := msg.Unpack(raw); err != nil {
+ msg.Data = raw
+ if err := msg.Unpack(); err != nil {
http.Error(w, "invalid dns message", http.StatusBadRequest)
return
}
@@ -53,22 +54,21 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
}
resp, _ := s.buildResponse(msg)
- packed, err := resp.Pack()
- if err != nil {
+ if err := resp.Pack(); err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/dns-message")
w.Header().Set("Cache-Control", "no-cache, max-age=0")
- if _, err := w.Write(packed); err != nil {
+ if _, err := w.Write(resp.Data); err != nil {
slog.Error("doh write failed", "err", err)
return
}
slog.Info("doh query served",
- "qname", msg.Question[0].Name,
- "qtype", dns.TypeToString[msg.Question[0].Qtype],
+ "qname", msg.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])],
"rcode", dns.RcodeToString[resp.Rcode],
"client", r.RemoteAddr,
)
diff --git a/internal/server/handler.go b/internal/server/handler.go
index 5f873d4..d0aa705 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -2,36 +2,43 @@ package server
import (
"context"
+ "io"
"log/slog"
- "net"
+ "net/netip"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
+ "codeberg.org/miekg/dns/rdata"
"linum/internal/blocklist"
"linum/internal/cache"
)
-func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
+func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
if len(req.Question) == 0 {
m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- _ = w.WriteMsg(m)
+ m.Rcode = dns.RcodeFormatError
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
+ io.Copy(w, m)
return
}
resp, blocked := s.buildResponse(req)
- if err := w.WriteMsg(resp); err != nil {
+ resp.ID = req.ID
+ resp.Data = nil
+ if _, err := io.Copy(w, resp); err != nil {
slog.Error("write response failed",
"err", err,
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "qname", req.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
)
return
}
slog.Info("query served",
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "qname", req.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
"rcode", dns.RcodeToString[resp.Rcode],
"blocked", blocked,
)
@@ -39,18 +46,26 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
- return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false
+ m := new(dns.Msg)
+ m.Rcode = dns.RcodeFormatError
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
+ return m, false
}
q := req.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
- if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) {
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
return s.blockedResponse(req), true
}
if s.cache != nil {
- key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
if cached, ok := s.cache.Get(key); ok {
- cached.Id = req.Id
+ cached.ID = req.ID
return cached, false
}
}
@@ -58,22 +73,25 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
- reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
+ reply, err := s.resolver.Lookup(ctx, qname, qtype)
if err != nil {
slog.Error("resolution failed",
"err", err,
- "qname", q.Name,
- "qtype", dns.TypeToString[q.Qtype],
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
)
m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeServerFailure)
+ m.Rcode = dns.RcodeServerFailure
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
return m, false
}
- reply.Id = req.Id
+ reply.ID = req.ID
if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
- key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
s.cache.Set(key, reply, 0)
}
@@ -82,7 +100,11 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
m := new(dns.Msg)
- m.SetReply(req)
+ m.Response = true
+ m.ID = req.ID
+ m.Opcode = req.Opcode
+ m.RecursionDesired = req.RecursionDesired
+ m.Question = req.Question
m.Authoritative = true
if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
@@ -91,16 +113,19 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
}
q := req.Question[0]
- switch q.Qtype {
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+
+ switch qtype {
case dns.TypeA:
m.Answer = append(m.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
- A: net.IPv4(0, 0, 0, 0),
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{})},
})
case dns.TypeAAAA:
m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
- AAAA: net.IPv6zero,
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
})
}
return m
diff --git a/internal/server/server.go b/internal/server/server.go
index 8f991eb..1aa3256 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,7 +6,7 @@ import (
"net/http"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"linum/internal/resolver"
"linum/internal/blocklist"
"linum/internal/cache"
@@ -41,7 +41,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
Handler: mux,
UDPSize: 4096,
ReadTimeout: 5 * time.Second,
- WriteTimeout: 5 * time.Second,
ReusePort: true,
}
}
@@ -52,7 +51,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
Net: "tcp",
Handler: mux,
ReadTimeout: 5 * time.Second,
- WriteTimeout: 5 * time.Second,
}
}
@@ -97,10 +95,10 @@ func (s *Server) Run(ctx context.Context) error {
defer cancel()
if s.udp != nil {
- s.udp.ShutdownContext(shutdownCtx)
+ s.udp.Shutdown(shutdownCtx)
}
if s.tcp != nil {
- s.tcp.ShutdownContext(shutdownCtx)
+ s.tcp.Shutdown(shutdownCtx)
}
if s.doh != nil {
s.doh.Shutdown(shutdownCtx)
@@ -113,10 +111,10 @@ func (s *Server) Run(ctx context.Context) error {
func (s *Server) Close() error {
if s.udp != nil {
- s.udp.Shutdown()
+ s.udp.Shutdown(context.Background())
}
if s.tcp != nil {
- s.tcp.Shutdown()
+ s.tcp.Shutdown(context.Background())
}
if s.doh != nil {
s.doh.Close()
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 42938d1..002aba8 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -6,7 +6,7 @@ import (
"testing"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"linum/internal/resolver"
)
@@ -27,7 +27,7 @@ func TestBuildResponse(t *testing.T) {
tests := []struct {
name string
req *dns.Msg
- wantRcode int
+ wantRcode uint16
wantAnswers int
wantEdns0 bool
}{
@@ -52,10 +52,10 @@ func TestBuildResponse(t *testing.T) {
t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers)
}
if tt.wantEdns0 {
- if opt := resp.IsEdns0(); opt == nil {
+ if resp.UDPSize == 0 {
t.Error("expected EDNS0 in response, got none")
- } else if opt.UDPSize() != 4096 {
- t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize())
+ } else if resp.UDPSize != 4096 {
+ t.Errorf("edns0 udp size: got %d, want 4096", resp.UDPSize)
}
}
})
@@ -64,9 +64,7 @@ func TestBuildResponse(t *testing.T) {
func TestBuildResponseWithQuery(t *testing.T) {
s := testServer(t)
- // Valid query → must not panic, rcode must be valid
- m := new(dns.Msg)
- m.SetQuestion("example.com.", dns.TypeA)
+ m := dns.NewMsg("example.com.", dns.TypeA)
resp, _ := s.buildResponse(m)
if resp == nil {
t.Fatal("buildResponse returned nil")
@@ -80,7 +78,6 @@ func FuzzBuildResponse(f *testing.F) {
baseCtx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Server{logger: slog.Default(), baseCtx: baseCtx}
- // For fuzz, use a resolver that won't make real network calls
s.resolver = resolver.New(
resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
resolver.WithTimeout(10*time.Millisecond),
@@ -103,14 +100,15 @@ func FuzzBuildResponse(f *testing.F) {
f.Add(seed)
f.Fuzz(func(t *testing.T, data []byte) {
msg := new(dns.Msg)
- if err := msg.Unpack(data); err != nil {
+ msg.Data = data
+ if err := msg.Unpack(); err != nil {
return
}
resp, _ := s.buildResponse(msg)
if resp == nil {
t.Fatal("buildResponse returned nil")
}
- if _, err := resp.Pack(); err != nil {
+ if err := resp.Pack(); err != nil {
t.Errorf("pack failed: %v", err)
}
})