summaryrefslogtreecommitdiff
path: root/internal/cache/cache.go
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
committerradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
commitf5753c6a8cac5a57a042b0388f38abeff5d1f37d (patch)
tree96e1241126b23051725edb68a79c8e4603d7e23a /internal/cache/cache.go
parente05835493f821055e517a3988c6f9256abbc5c24 (diff)
migration to new dns library
Diffstat (limited to 'internal/cache/cache.go')
-rw-r--r--internal/cache/cache.go102
1 files changed, 85 insertions, 17 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 8b13dd1..bb35b8e 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -7,7 +7,7 @@ import (
"sync/atomic"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
_ "modernc.org/sqlite"
)
@@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) {
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
- msg := e.msg.Copy()
+ msg := deepCopyMsg(e.msg)
remaining := e.remaining()
adjustTTL(msg, e.ttl, remaining)
return msg, true
@@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
}
e := &entry{
- msg: msg.Copy(),
+ msg: deepCopyMsg(msg),
storedAt: time.Now(),
ttl: ttl,
}
@@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) {
if c.db == nil {
return
}
- data, err := e.msg.Pack()
+ err := e.msg.Pack()
if err != nil {
return
}
+ data := e.msg.Data
_, err = c.db.Exec(
`INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
VALUES (?, ?, ?, ?, ?, ?)`,
@@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() {
}
msg := new(dns.Msg)
- if err := msg.Unpack(data); err != nil {
+ msg.Data = data
+ if err := msg.Unpack(); err != nil {
continue
}
@@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration {
var min uint32
first := true
for _, rr := range msg.Answer {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Ns {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
@@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration {
func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
ratio := float64(remaining) / float64(originalTTL)
for _, rr := range msg.Answer {
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Ns {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
}
func applyRatio(ttl uint32, ratio float64) uint32 {
return uint32(float64(ttl) * ratio)
}
+
+func deepCopyMsg(msg *dns.Msg) *dns.Msg {
+ cp := new(dns.Msg)
+ cp.Response = msg.Response
+ cp.ID = msg.ID
+ cp.Opcode = msg.Opcode
+ cp.Authoritative = msg.Authoritative
+ cp.Truncated = msg.Truncated
+ cp.RecursionDesired = msg.RecursionDesired
+ cp.RecursionAvailable = msg.RecursionAvailable
+ cp.Rcode = msg.Rcode
+ cp.UDPSize = msg.UDPSize
+ cp.Question = copyRRSlice(msg.Question)
+ cp.Answer = copyRRSlice(msg.Answer)
+ cp.Ns = copyRRSlice(msg.Ns)
+ cp.Extra = copyRRSlice(msg.Extra)
+ return cp
+}
+
+func copyRRSlice(src []dns.RR) []dns.RR {
+ if src == nil {
+ return nil
+ }
+ dst := make([]dns.RR, len(src))
+ for i, rr := range src {
+ dst[i] = copyRR(rr)
+ }
+ return dst
+}
+
+func copyRR(rr dns.RR) dns.RR {
+ if rr == nil {
+ return nil
+ }
+ switch v := rr.(type) {
+ case *dns.A:
+ cp := *v
+ return &cp
+ case *dns.AAAA:
+ cp := *v
+ return &cp
+ case *dns.NS:
+ cp := *v
+ return &cp
+ case *dns.CNAME:
+ cp := *v
+ return &cp
+ case *dns.SOA:
+ cp := *v
+ return &cp
+ case *dns.MX:
+ cp := *v
+ return &cp
+ case *dns.TXT:
+ cp := *v
+ return &cp
+ case *dns.SRV:
+ cp := *v
+ return &cp
+ case *dns.PTR:
+ cp := *v
+ return &cp
+ default:
+ return rr
+ }
+}