diff options
Diffstat (limited to 'internal/cache')
| -rw-r--r-- | internal/cache/cache.go | 102 | ||||
| -rw-r--r-- | internal/cache/cache_test.go | 27 |
2 files changed, 99 insertions, 30 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 8b13dd1..bb35b8e 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -7,7 +7,7 @@ import ( "sync/atomic" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" _ "modernc.org/sqlite" ) @@ -119,7 +119,7 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) { c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) - msg := e.msg.Copy() + msg := deepCopyMsg(e.msg) remaining := e.remaining() adjustTTL(msg, e.ttl, remaining) return msg, true @@ -134,7 +134,7 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { } e := &entry{ - msg: msg.Copy(), + msg: deepCopyMsg(msg), storedAt: time.Now(), ttl: ttl, } @@ -228,10 +228,11 @@ func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } - data, err := e.msg.Pack() + err := e.msg.Pack() if err != nil { return } + data := e.msg.Data _, err = c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) VALUES (?, ?, ?, ?, ?, ?)`, @@ -269,7 +270,8 @@ func (c *Cache) loadFromDB() { } msg := new(dns.Msg) - if err := msg.Unpack(data); err != nil { + msg.Data = data + if err := msg.Unpack(); err != nil { continue } @@ -300,23 +302,23 @@ func computeTTL(msg *dns.Msg) time.Duration { var min uint32 first := true for _, rr := range msg.Answer { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Ns { - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - if first || rr.Header().Ttl < min { - min = rr.Header().Ttl + if first || rr.Header().TTL < min { + min = rr.Header().TTL first = false } } @@ -334,22 +336,88 @@ func computeTTL(msg *dns.Msg) time.Duration { func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { ratio := float64(remaining) / float64(originalTTL) for _, rr := range msg.Answer { - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Ns { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } for _, rr := range msg.Extra { - if rr.Header().Rrtype == dns.TypeOPT { + if dns.RRToType(rr) == dns.TypeOPT { continue } - rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) } } func applyRatio(ttl uint32, ratio float64) uint32 { return uint32(float64(ttl) * ratio) } + +func deepCopyMsg(msg *dns.Msg) *dns.Msg { + cp := new(dns.Msg) + cp.Response = msg.Response + cp.ID = msg.ID + cp.Opcode = msg.Opcode + cp.Authoritative = msg.Authoritative + cp.Truncated = msg.Truncated + cp.RecursionDesired = msg.RecursionDesired + cp.RecursionAvailable = msg.RecursionAvailable + cp.Rcode = msg.Rcode + cp.UDPSize = msg.UDPSize + cp.Question = copyRRSlice(msg.Question) + cp.Answer = copyRRSlice(msg.Answer) + cp.Ns = copyRRSlice(msg.Ns) + cp.Extra = copyRRSlice(msg.Extra) + return cp +} + +func copyRRSlice(src []dns.RR) []dns.RR { + if src == nil { + return nil + } + dst := make([]dns.RR, len(src)) + for i, rr := range src { + dst[i] = copyRR(rr) + } + return dst +} + +func copyRR(rr dns.RR) dns.RR { + if rr == nil { + return nil + } + switch v := rr.(type) { + case *dns.A: + cp := *v + return &cp + case *dns.AAAA: + cp := *v + return &cp + case *dns.NS: + cp := *v + return &cp + case *dns.CNAME: + cp := *v + return &cp + case *dns.SOA: + cp := *v + return &cp + case *dns.MX: + cp := *v + return &cp + case *dns.TXT: + cp := *v + return &cp + case *dns.SRV: + cp := *v + return &cp + case *dns.PTR: + cp := *v + return &cp + default: + return rr + } +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 6556dcb..33e1d30 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -2,11 +2,12 @@ package cache import ( "fmt" - "net" + "net/netip" "testing" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/rdata" ) func TestSetGet(t *testing.T) { @@ -18,8 +19,8 @@ func TestSetGet(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300}, - A: net.IPv4(1,2,3,4), + Hdr: dns.Header{Name: "example.com.", Class: dns.ClassINET, TTL: 300}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})}, }) key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} @@ -33,8 +34,8 @@ func TestSetGet(t *testing.T) { t.Fatalf("expected 1 answer, got %d", len(got.Answer)) } a, _ := got.Answer[0].(*dns.A) - if !a.A.Equal(net.IPv4(1,2,3,4)) { - t.Errorf("IP = %s, want 1.2.3.4", a.A) + if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { + t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } } @@ -67,7 +68,7 @@ func TestEviction(t *testing.T) { func TestComputeTTL(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120}, + Hdr: dns.Header{Name: "x.", Class: dns.ClassINET, TTL: 120}, }) if d := computeTTL(msg); d != 120*time.Second { t.Errorf("TTL = %v, want 120s", d) @@ -78,8 +79,8 @@ func TestNegativeTTL(t *testing.T) { msg := new(dns.Msg) msg.Rcode = dns.RcodeNameError msg.Ns = append(msg.Ns, &dns.SOA{ - Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900}, - Minttl: 300, + Hdr: dns.Header{Name: "com.", Class: dns.ClassINET, TTL: 900}, + SOA: rdata.SOA{Minttl: 300}, }) if d := computeTTL(msg); d != 300*time.Second { t.Errorf("negative TTL = %v, want 300s", d) @@ -116,8 +117,8 @@ func TestSQLitePersistence(t *testing.T) { msg := new(dns.Msg) msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300}, - A: net.IPv4(1, 2, 3, 4), + Hdr: dns.Header{Name: "x.com.", Class: dns.ClassINET, TTL: 300}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{1, 2, 3, 4})}, }) key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} c.Set(key, msg, 300*time.Second) @@ -134,7 +135,7 @@ func TestSQLitePersistence(t *testing.T) { t.Fatal("expected cache hit from SQLite load") } a, _ := got.Answer[0].(*dns.A) - if !a.A.Equal(net.IPv4(1, 2, 3, 4)) { - t.Errorf("IP = %s, want 1.2.3.4", a.A) + if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { + t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } } |
