package cache import ( "database/sql" "log/slog" "sync" "sync/atomic" "time" "codeberg.org/miekg/dns" _ "modernc.org/sqlite" ) type Key struct { Name string Qtype uint16 Class uint16 } type entry struct { packed []byte storedAt time.Time ttl time.Duration } func (e *entry) expired() bool { return time.Since(e.storedAt) >= e.ttl } func (e *entry) remaining() time.Duration { d := e.ttl - time.Since(e.storedAt) if d < 0 { return 0 } return d } type Cache struct { mu sync.RWMutex entries map[Key]*entry maxSize int db *sql.DB dbCh chan dbWrite wg sync.WaitGroup hits int64 misses int64 evicted int64 stopCh chan struct{} } type dbWrite struct { key Key e *entry } func NewCache(maxSize int, dbPath string) (*Cache, error) { if maxSize <= 0 { maxSize = 100000 } c := &Cache{ entries: make(map[Key]*entry), maxSize: maxSize, stopCh: make(chan struct{}), } if dbPath != "" { db, err := sql.Open("sqlite", dbPath) if err != nil { return nil, err } c.db = db db.Exec("PRAGMA journal_mode=WAL") db.Exec("PRAGMA synchronous=NORMAL") db.Exec("PRAGMA cache_size=-65536") if _, err := db.Exec(` CREATE TABLE IF NOT EXISTS cache ( name TEXT NOT NULL, qtype INTEGER NOT NULL, class INTEGER NOT NULL, data BLOB NOT NULL, stored_at INTEGER NOT NULL, ttl_ns INTEGER NOT NULL, PRIMARY KEY (name, qtype, class) ) `); err != nil { db.Close() return nil, err } c.loadFromDB() c.dbCh = make(chan dbWrite, 1024) c.wg.Add(1) go c.dbWriter() } go c.evictLoop() return c, nil } func (c *Cache) Stop() { close(c.stopCh) if c.db != nil { close(c.dbCh) } c.wg.Wait() if c.db != nil { c.db.Close() } } func (c *Cache) Get(key Key) ([]byte, bool) { c.mu.RLock() e, ok := c.entries[key] if !ok || e.expired() { c.mu.RUnlock() atomic.AddInt64(&c.misses, 1) return nil, false } c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) cp := make([]byte, len(e.packed)) copy(cp, e.packed) cp[0] = 0 cp[1] = 0 return cp, true } func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { if ttl <= 0 { ttl = computeTTL(msg) } if ttl <= 0 { ttl = 60 * time.Second } if err := msg.Pack(); err != nil { return } cp := make([]byte, len(msg.Data)) copy(cp, msg.Data) e := &entry{ packed: cp, storedAt: time.Now(), ttl: ttl, } c.mu.Lock() if len(c.entries) >= c.maxSize { c.evictLocked() } c.entries[key] = e c.mu.Unlock() if c.db != nil { select { case c.dbCh <- dbWrite{key: key, e: e}: default: } } } func (c *Cache) Stats() (int64, int64, int64) { return atomic.LoadInt64(&c.hits), atomic.LoadInt64(&c.misses), atomic.LoadInt64(&c.evicted) } func (c *Cache) Len() int { c.mu.RLock() defer c.mu.RUnlock() return len(c.entries) } func (c *Cache) evictLocked() { now := time.Now() for k, e := range c.entries { if now.After(e.storedAt.Add(e.ttl)) { delete(c.entries, k) atomic.AddInt64(&c.evicted, 1) } } if len(c.entries) < c.maxSize { return } var oldestKey Key var oldestTime time.Time first := true for k, e := range c.entries { if first || e.storedAt.Before(oldestTime) { oldestKey = k oldestTime = e.storedAt first = false } } if !first { delete(c.entries, oldestKey) atomic.AddInt64(&c.evicted, 1) } } func (c *Cache) evictLoop() { tk := time.NewTicker(1 * time.Minute) defer tk.Stop() for { select { case <-tk.C: now := time.Now() c.mu.RLock() var expired []Key for k, e := range c.entries { if now.After(e.storedAt.Add(e.ttl)) { expired = append(expired, k) } } c.mu.RUnlock() if len(expired) > 0 { c.mu.Lock() for _, k := range expired { if _, ok := c.entries[k]; ok { delete(c.entries, k) atomic.AddInt64(&c.evicted, 1) } } c.mu.Unlock() } case <-c.stopCh: return } } } func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } _, err := c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) VALUES (?,?,?,?,?,?)`, key.Name, key.Qtype, key.Class, e.packed, e.storedAt.UnixNano(), int64(e.ttl), ) if err != nil { slog.Warn("cache write to db failed", "err", err) } } func (c *Cache) dbWriter() { defer c.wg.Done() for w := range c.dbCh { c.writeToDB(w.key, w.e) } } func (c *Cache) loadFromDB() { rows, err := c.db.Query( `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`, ) if err != nil { return } defer rows.Close() for rows.Next() { var name string var qtype, class uint16 var data []byte var storedAtNano, ttlNs int64 if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil { continue } e := &entry{ packed: data, storedAt: time.Unix(0, storedAtNano), ttl: time.Duration(ttlNs), } if !e.expired() { c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e } else { c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class) } } } func computeTTL(msg *dns.Msg) time.Duration { if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { for _, rr := range msg.Ns { if soa, ok := rr.(*dns.SOA); ok { return time.Duration(soa.Minttl) * time.Second } } } var min uint32 first := true for _, rr := range msg.Answer { if first || rr.Header().TTL < min { min = rr.Header().TTL first = false } } for _, rr := range msg.Ns { if first || rr.Header().TTL < min { min = rr.Header().TTL first = false } } for _, rr := range msg.Extra { if dns.RRToType(rr) == dns.TypeOPT { continue } if first || rr.Header().TTL < min { min = rr.Header().TTL first = false } } if !first { return time.Duration(min) * time.Second } for _, rr := range msg.Ns { if soa, ok := rr.(*dns.SOA); ok { return time.Duration(soa.Minttl) * time.Second } } return 0 }