summaryrefslogtreecommitdiff
path: root/internal/cache
diff options
context:
space:
mode:
Diffstat (limited to 'internal/cache')
-rw-r--r--internal/cache/cache.go159
-rw-r--r--internal/cache/cache_test.go20
2 files changed, 46 insertions, 133 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index bb35b8e..b0a0959 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -12,18 +12,18 @@ import (
)
type Key struct {
- Name string
+ Name string
Qtype uint16
Class uint16
}
type entry struct {
- msg *dns.Msg
+ packed []byte
storedAt time.Time
- ttl time.Duration
+ ttl time.Duration
}
-func (e *entry) expired() bool {
+func (e *entry) expired() bool {
return time.Since(e.storedAt) >= e.ttl
}
@@ -36,20 +36,20 @@ func (e *entry) remaining() time.Duration {
}
type Cache struct {
- mu sync.RWMutex
+ mu sync.RWMutex
entries map[Key]*entry
maxSize int
- db *sql.DB
- dbCh chan dbWrite
- wg sync.WaitGroup
- hits int64
- misses int64
+ db *sql.DB
+ dbCh chan dbWrite
+ wg sync.WaitGroup
+ hits int64
+ misses int64
evicted int64
- stopCh chan struct{}
+ stopCh chan struct{}
}
type dbWrite struct {
key Key
- e *entry
+ e *entry
}
func NewCache(maxSize int, dbPath string) (*Cache, error) {
@@ -59,7 +59,7 @@ func NewCache(maxSize int, dbPath string) (*Cache, error) {
c := &Cache{
entries: make(map[Key]*entry),
maxSize: maxSize,
- stopCh: make(chan struct{}),
+ stopCh: make(chan struct{}),
}
if dbPath != "" {
@@ -108,7 +108,7 @@ func (c *Cache) Stop() {
}
}
-func (c *Cache) Get(key Key) (*dns.Msg, bool) {
+func (c *Cache) Get(key Key) ([]byte, bool) {
c.mu.RLock()
e, ok := c.entries[key]
if !ok || e.expired() {
@@ -119,10 +119,11 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) {
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
- msg := deepCopyMsg(e.msg)
- remaining := e.remaining()
- adjustTTL(msg, e.ttl, remaining)
- return msg, true
+ cp := make([]byte, len(e.packed))
+ copy(cp, e.packed)
+ cp[0] = 0
+ cp[1] = 0
+ return cp, true
}
func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
@@ -133,10 +134,16 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
ttl = 60 * time.Second
}
+ if err := msg.Pack(); err != nil {
+ return
+ }
+ cp := make([]byte, len(msg.Data))
+ copy(cp, msg.Data)
+
e := &entry{
- msg: deepCopyMsg(msg),
+ packed: cp,
storedAt: time.Now(),
- ttl: ttl,
+ ttl: ttl,
}
c.mu.Lock()
@@ -156,8 +163,8 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
func (c *Cache) Stats() (int64, int64, int64) {
return atomic.LoadInt64(&c.hits),
- atomic.LoadInt64(&c.misses),
- atomic.LoadInt64(&c.evicted)
+ atomic.LoadInt64(&c.misses),
+ atomic.LoadInt64(&c.evicted)
}
func (c *Cache) Len() int {
@@ -228,15 +235,9 @@ func (c *Cache) writeToDB(key Key, e *entry) {
if c.db == nil {
return
}
- err := e.msg.Pack()
- if err != nil {
- return
- }
- data := e.msg.Data
- _, err = c.db.Exec(
+ _, err := c.db.Exec(
`INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
- VALUES (?, ?, ?, ?, ?, ?)`,
- key.Name, key.Qtype, key.Class, data,
+ VALUES (?,?,?,?,?,?)`, key.Name, key.Qtype, key.Class, e.packed,
e.storedAt.UnixNano(), int64(e.ttl),
)
if err != nil {
@@ -269,14 +270,8 @@ func (c *Cache) loadFromDB() {
continue
}
- msg := new(dns.Msg)
- msg.Data = data
- if err := msg.Unpack(); err != nil {
- continue
- }
-
e := &entry{
- msg: msg,
+ packed: data,
storedAt: time.Unix(0, storedAtNano),
ttl: time.Duration(ttlNs),
}
@@ -289,7 +284,6 @@ func (c *Cache) loadFromDB() {
}
}
-
func computeTTL(msg *dns.Msg) time.Duration {
if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 {
for _, rr := range msg.Ns {
@@ -332,92 +326,3 @@ func computeTTL(msg *dns.Msg) time.Duration {
}
return 0
}
-
-func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
- ratio := float64(remaining) / float64(originalTTL)
- for _, rr := range msg.Answer {
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
- for _, rr := range msg.Ns {
- if dns.RRToType(rr) == dns.TypeOPT {
- continue
- }
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
- for _, rr := range msg.Extra {
- if dns.RRToType(rr) == dns.TypeOPT {
- continue
- }
- rr.Header().TTL = applyRatio(rr.Header().TTL, ratio)
- }
-}
-
-func applyRatio(ttl uint32, ratio float64) uint32 {
- return uint32(float64(ttl) * ratio)
-}
-
-func deepCopyMsg(msg *dns.Msg) *dns.Msg {
- cp := new(dns.Msg)
- cp.Response = msg.Response
- cp.ID = msg.ID
- cp.Opcode = msg.Opcode
- cp.Authoritative = msg.Authoritative
- cp.Truncated = msg.Truncated
- cp.RecursionDesired = msg.RecursionDesired
- cp.RecursionAvailable = msg.RecursionAvailable
- cp.Rcode = msg.Rcode
- cp.UDPSize = msg.UDPSize
- cp.Question = copyRRSlice(msg.Question)
- cp.Answer = copyRRSlice(msg.Answer)
- cp.Ns = copyRRSlice(msg.Ns)
- cp.Extra = copyRRSlice(msg.Extra)
- return cp
-}
-
-func copyRRSlice(src []dns.RR) []dns.RR {
- if src == nil {
- return nil
- }
- dst := make([]dns.RR, len(src))
- for i, rr := range src {
- dst[i] = copyRR(rr)
- }
- return dst
-}
-
-func copyRR(rr dns.RR) dns.RR {
- if rr == nil {
- return nil
- }
- switch v := rr.(type) {
- case *dns.A:
- cp := *v
- return &cp
- case *dns.AAAA:
- cp := *v
- return &cp
- case *dns.NS:
- cp := *v
- return &cp
- case *dns.CNAME:
- cp := *v
- return &cp
- case *dns.SOA:
- cp := *v
- return &cp
- case *dns.MX:
- cp := *v
- return &cp
- case *dns.TXT:
- cp := *v
- return &cp
- case *dns.SRV:
- cp := *v
- return &cp
- case *dns.PTR:
- cp := *v
- return &cp
- default:
- return rr
- }
-}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
index 33e1d30..aa0aeb0 100644
--- a/internal/cache/cache_test.go
+++ b/internal/cache/cache_test.go
@@ -26,14 +26,18 @@ func TestSetGet(t *testing.T) {
key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
c.Set(key, msg, 300*time.Second)
- got, ok := c.Get(key)
+ packed, ok := c.Get(key)
if !ok {
t.Fatal("expected cache hit")
}
- if len(got.Answer) != 1 {
- t.Fatalf("expected 1 answer, got %d", len(got.Answer))
+ msg.Data = packed
+ if err := msg.Unpack(); err != nil {
+ t.Fatal(err)
+ }
+ if len(msg.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(msg.Answer))
}
- a, _ := got.Answer[0].(*dns.A)
+ a, _ := msg.Answer[0].(*dns.A)
if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}
@@ -130,11 +134,15 @@ func TestSQLitePersistence(t *testing.T) {
}
defer c2.Stop()
- got, ok := c2.Get(key)
+ packed, ok := c2.Get(key)
if !ok {
t.Fatal("expected cache hit from SQLite load")
}
- a, _ := got.Answer[0].(*dns.A)
+ msg.Data = packed
+ if err := msg.Unpack(); err != nil {
+ t.Fatal(err)
+ }
+ a, _ := msg.Answer[0].(*dns.A)
if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) {
t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr)
}