From f5753c6a8cac5a57a042b0388f38abeff5d1f37d Mon Sep 17 00:00:00 2001 From: radhitya Date: Thu, 18 Jun 2026 12:42:29 +0700 Subject: migration to new dns library --- internal/cache/cache.go | 102 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 17 deletions(-) (limited to 'internal/cache/cache.go') diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 8b13dd1..bb35b8e 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -7,7 +7,7 @@ import ( "sync/atomic" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" _ "modernc.org/sqlite" ) @@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) { c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) - msg := e.msg.Copy() + msg := deepCopyMsg(e.msg) remaining := e.remaining() adjustTTL(msg, e.ttl, remaining) return msg, true @@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { } e := &entry{ - msg: msg.Copy(), + msg: deepCopyMsg(msg), storedAt: time.Now(), ttl: ttl, } @@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } - data, err := e.msg.Pack() + err := e.msg.Pack() if err != nil { return } + data := e.msg.Data _, err = c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) VALUES (?, ?, ?, ?, ?, ?)`, @@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() { } msg := new(dns.Msg) - if err := msg.Unpack(data); err != nil { + msg.Data = data + if err := msg.Unpack(); err != nil { continue } @@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration { var min uint32 first := true for _, rr := range msg.Answer { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Ns { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } @@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration { func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { ratio := float64(remaining) / float64(originalTTL) for _, rr := range msg.Answer { - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Ns { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } } func applyRatio(ttl uint32, ratio float64) uint32 { return uint32(float64(ttl) * ratio) } + +func deepCopyMsg(msg *dns.Msg) *dns.Msg { + cp := new(dns.Msg) + cp.Response = msg.Response + cp.ID = msg.ID + cp.Opcode = msg.Opcode + cp.Authoritative = msg.Authoritative + cp.Truncated = msg.Truncated + cp.RecursionDesired = msg.RecursionDesired + cp.RecursionAvailable = msg.RecursionAvailable + cp.Rcode = msg.Rcode + cp.UDPSize = msg.UDPSize + cp.Question = copyRRSlice(msg.Question) + cp.Answer = copyRRSlice(msg.Answer) + cp.Ns = copyRRSlice(msg.Ns) + cp.Extra = copyRRSlice(msg.Extra) + return cp +} + +func copyRRSlice(src []dns.RR) []dns.RR { + if src == nil { + return nil + } + dst := make([]dns.RR, len(src)) + for i, rr := range src { + dst[i] = copyRR(rr) + } + return dst +} + +func copyRR(rr dns.RR) dns.RR { + if rr == nil { + return nil + } + switch v := rr.(type) { + case *dns.A: + cp := *v + return &cp + case *dns.AAAA: + cp := *v + return &cp + case *dns.NS: + cp := *v + return &cp + case *dns.CNAME: + cp := *v + return &cp + case *dns.SOA: + cp := *v + return &cp + case *dns.MX: + cp := *v + return &cp + case *dns.TXT: + cp := *v + return &cp + case *dns.SRV: + cp := *v + return &cp + case *dns.PTR: + cp := *v + return &cp + default: + return rr + } +} -- cgit v1.2.3