diff options
Diffstat (limited to 'internal/cache/cache.go')
| -rw-r--r-- | internal/cache/cache.go | 317 |
1 files changed, 317 insertions, 0 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a2d86a0 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,317 @@ +package cache + +import ( + "database/sql" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + _ "modernc.org/sqlite" +) + +type Key struct { + Name string + Qtype uint16 + Class uint16 +} + +type entry struct { + msg *dns.Msg + storedAt time.Time + ttl time.Duration +} + +func (e *entry) expired() bool { + return time.Since(e.storedAt) >= e.ttl +} + +func (e *entry) remaining() time.Duration { + d := e.ttl - time.Since(e.storedAt) + if d < 0 { + return 0 + } + return d +} + +type Cache struct { + mu sync.RWMutex + entries map[Key]*entry + maxSize int + db *sql.DB + + hits int64 + misses int64 + evicted int64 + stopCh chan struct{} +} + +func NewCache(maxSize int, dbPath string) (*Cache, error) { + if maxSize <= 0 { + maxSize = 100000 + } + c := &Cache{ + entries: make(map[Key]*entry), + maxSize: maxSize, + stopCh: make(chan struct{}), + } + + if dbPath != "" { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + c.db = db + + db.Exec("PRAGMA journal_mode=WAL") + db.Exec("PRAGMA synchronous=NORMAL") + db.Exec("PRAGMA cache_size=-65536") + + if _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS cache ( + name TEXT NOT NULL, + qtype INTEGER NOT NULL, + class INTEGER NOT NULL, + data BLOB NOT NULL, + stored_at INTEGER NOT NULL, + ttl_ns INTEGER NOT NULL, + PRIMARY KEY (name, qtype, class) + ) + `); err != nil { + db.Close() + return nil, err + } + + c.loadFromDB() + } + go c.evictLoop() + return c, nil +} + +func (c *Cache) Stop() { + close(c.stopCh) + if c.db != nil { + c.db.Close() + } +} + +func (c *Cache) Get(key Key) (*dns.Msg, bool) { + c.mu.RLock() + e, ok := c.entries[key] + if !ok || e.expired() { + c.mu.RUnlock() + atomic.AddInt64(&c.misses, 1) + return nil, false + } + c.mu.RUnlock() + + atomic.AddInt64(&c.hits, 1) + msg := e.msg.Copy() + remaining := e.remaining() + adjustTTL(msg, e.ttl, remaining) + return msg, true +} + +func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { + if ttl <= 0 { + ttl = computeTTL(msg) + } + if ttl <= 0 { + ttl = 60 * time.Second + } + + e := &entry{ + msg: msg.Copy(), + storedAt: time.Now(), + ttl: ttl, + } + + c.mu.Lock() + if len(c.entries) >= c.maxSize { + c.evictLocked() + } + c.entries[key] = e + c.mu.Unlock() + + if c.db != nil { + c.writeToDB(key, e) + } +} + +func (c *Cache) Stats() (int64, int64, int64) { + return atomic.LoadInt64(&c.hits), + atomic.LoadInt64(&c.misses), + atomic.LoadInt64(&c.evicted) +} + +func (c *Cache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.entries) +} + +func (c *Cache) evictLocked() { + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + if len(c.entries) < c.maxSize { + return + } + + var oldestKey Key + var oldestTime time.Time + first := true + for k, e := range c.entries { + if first || e.storedAt.Before(oldestTime) { + oldestKey = k + oldestTime = e.storedAt + first = false + } + } + if !first { + delete(c.entries, oldestKey) + atomic.AddInt64(&c.evicted, 1) + } +} + +func (c *Cache) evictLoop() { + tk := time.NewTicker(1 * time.Minute) + defer tk.Stop() + for { + select { + case <-tk.C: + c.mu.Lock() + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + c.mu.Unlock() + case <-c.stopCh: + return + } + } +} +func (c *Cache) writeToDB(key Key, e *entry) { + data, err := e.msg.Pack() + if err != nil { + return + } + c.db.Exec( + `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) + VALUES (?, ?, ?, ?, ?, ?)`, + key.Name, key.Qtype, key.Class, data, + e.storedAt.UnixNano(), int64(e.ttl), + ) +} + +func (c *Cache) loadFromDB() { + rows, err := c.db.Query( + `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`, + ) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var name string + var qtype, class uint16 + var data []byte + var storedAtNano, ttlNs int64 + + if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil { + continue + } + + msg := new(dns.Msg) + if err := msg.Unpack(data); err != nil { + continue + } + + e := &entry{ + msg: msg, + storedAt: time.Unix(0, storedAtNano), + ttl: time.Duration(ttlNs), + } + + if !e.expired() { + c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e + } else { + c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class) + } + } +} + + +func computeTTL(msg *dns.Msg) time.Duration { + if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + } + + var min uint32 + first := true + for _, rr := range msg.Answer { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Ns { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + if !first { + return time.Duration(min) * time.Second + } + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + return 0 +} + +func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { + ratio := float64(remaining) / float64(originalTTL) + for _, rr := range msg.Answer { + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Ns { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } +} + +func applyRatio(ttl uint32, ratio float64) uint32 { + return uint32(float64(ttl) * ratio) +} |
