summaryrefslogtreecommitdiff
path: root/internal/cache
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
committerradhitya <alif@radhitya.org>2026-06-18 12:42:29 +0700
commitf5753c6a8cac5a57a042b0388f38abeff5d1f37d (patch)
tree96e1241126b23051725edb68a79c8e4603d7e23a /internal/cache
parente05835493f821055e517a3988c6f9256abbc5c24 (diff)
migration to new dns library
Diffstat (limited to 'internal/cache')
-rw-r--r--internal/cache/cache.go102
-rw-r--r--internal/cache/cache_test.go27
2 files changed, 99 insertions, 30 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 8b13dd1..bb35b8e 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -7,7 +7,7 @@ import (
"sync/atomic"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
_ "modernc.org/sqlite"
)
@@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) {
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
- msg := e.msg.Copy()
+ msg := deepCopyMsg(e.msg)
remaining := e.remaining()
adjustTTL(msg, e.ttl, remaining)
return msg, true
@@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
}
e := &entry{
- msg: msg.Copy(),
+ msg: deepCopyMsg(msg),
storedAt: time.Now(),
ttl: ttl,
}
@@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) {
if c.db == nil {
return
}
- data, err := e.msg.Pack()
+ err := e.msg.Pack()
if err != nil {
return
}
+ data := e.msg.Data
_, err = c.db.Exec(
`INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
VALUES (?, ?, ?, ?, ?, ?)`,
@@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() {
}
msg := new(dns.Msg)
- if err := msg.Unpack(data); err != nil {
+ msg.Data = data
+ if err := msg.Unpack(); err != nil {
continue
}
@@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration {
var min uint32
first := true
for _, rr := range msg.Answer {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Ns {
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- if first || rr.Header().Ttl < min {
- min = rr.Header().Ttl
+ if first || rr.Header().TTL < min {
+ min = rr.Header().TTL
first = false
}
}
@@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration {
func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
ratio := float64(remaining) / float64(originalTTL)
for _, rr := range msg.Answer {
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Ns {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
for _, rr := range msg.Extra {
- if rr.Header().Rrtype == dns.TypeOPT {
+ if dns.RRToType(rr) == dns.TypeOPT {
continue
}
- rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
}
}
func applyRatio(ttl uint32, ratio float64) uint32 {
return uint32(float64(ttl) * ratio)
}
+
+func deepCopyMsg(msg *dns.Msg) *dns.Msg {
+ cp := new(dns.Msg)
+ cp.Response = msg.Response
+ cp.ID = msg.ID
+ cp.Opcode = msg.Opcode
+ cp.Authoritative = msg.Authoritative
+ cp.Truncated = msg.Truncated
+ cp.RecursionDesired = msg.RecursionDesired
+ cp.RecursionAvailable = msg.RecursionAvailable
+ cp.Rcode = msg.Rcode
+ cp.UDPSize = msg.UDPSize
+ cp.Question = copyRRSlice(msg.Question)
+ cp.Answer = copyRRSlice(msg.Answer)
+ cp.Ns = copyRRSlice(msg.Ns)
+ cp.Extra = copyRRSlice(msg.Extra)
+ return cp
+}
+
+func copyRRSlice(src []dns.RR) []dns.RR {
+ if src == nil {
+ return nil
+ }
+ dst := make([]dns.RR, len(src))
+ for i, rr := range src {
+ dst[i] = copyRR(rr)
+ }
+ return dst
+}
+
+func copyRR(rr dns.RR) dns.RR {
+ if rr == nil {
+ return nil
+ }
+ switch v := rr.(type) {
+ case *dns.A:
+ cp := *v
+ return &cp
+ case *dns.AAAA:
+ cp := *v
+ return &cp
+ case *dns.NS:
+ cp := *v
+ return &cp
+ case *dns.CNAME:
+ cp := *v
+ return &cp
+ case *dns.SOA:
+ cp := *v
+ return &cp
+ case *dns.MX:
+ cp := *v
+ return &cp
+ case *dns.TXT:
+ cp := *v
+ return &cp
+ case *dns.SRV:
+ cp := *v
+ return &cp
+ case *dns.PTR:
+ cp := *v
+ return &cp
+ default:
+ return rr
+ }
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
index 6556dcb..33e1d30 100644
--- a/internal/cache/cache_test.go
+++ b/internal/cache/cache_test.go
@@ -2,11 +2,12 @@ package cache
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
+ "codeberg.org/miekg/dns/rdata"
)
func TestSetGet(t *testing.T) {
@@ -18,8 +19,8 @@ func TestSetGet(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300},
- A: net.IPv4(1,2,3,4),
+ Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})},
})
key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
@@ -33,8 +34,8 @@ func TestSetGet(t *testing.T) {
t.Fatalf("expected 1 answer, got %d", len(got.Answer))
}
a, _ := got.Answer[0].(*dns.A)
- if !a.A.Equal(net.IPv4(1,2,3,4)) {
- t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
}
@@ -67,7 +68,7 @@ func TestEviction(t *testing.T) {
func TestComputeTTL(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120},
+ Hdr: dns.Header{Name: "x.", Class: dns.ClassINET, TTL: 120},
})
if d := computeTTL(msg); d != 120*time.Second {
t.Errorf("TTL = %v, want 120s", d)
@@ -78,8 +79,8 @@ func TestNegativeTTL(t *testing.T) {
msg := new(dns.Msg)
msg.Rcode = dns.RcodeNameError
msg.Ns = append(msg.Ns, &dns.SOA{
- Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900},
- Minttl: 300,
+ Hdr: dns.Header{Name: "com.", Class: dns.ClassINET, TTL: 900},
+ SOA: rdata.SOA{Minttl: 300},
})
if d := computeTTL(msg); d != 300*time.Second {
t.Errorf("negative TTL = %v, want 300s", d)
@@ -116,8 +117,8 @@ func TestSQLitePersistence(t *testing.T) {
msg := new(dns.Msg)
msg.Answer = append(msg.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300},
- A: net.IPv4(1, 2, 3, 4),
+ Hdr: dns.Header{Name: "x.com.", Class: dns.ClassINET, TTL: 300},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})},
})
key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
c.Set(key, msg, 300*time.Second)
@@ -134,7 +135,7 @@ func TestSQLitePersistence(t *testing.T) {
t.Fatal("expected cache hit from SQLite load")
}
a, _ := got.Answer[0].(*dns.A)
- if !a.A.Equal(net.IPv4(1, 2, 3, 4)) {
- t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
}