summaryrefslogtreecommitdiff
path: root/internal/cache
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
committerradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
commit4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch)
tree12d9700e53775503ad7ba2beb72bedfc64bdd70d /internal/cache
parent3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff)
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
Diffstat (limited to 'internal/cache')
-rw-r--r--internal/cache/cache.go317
-rw-r--r--internal/cache/cache_test.go140
2 files changed, 457 insertions, 0 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
new file mode 100644
index 0000000..a2d86a0
--- /dev/null
+++ b/internal/cache/cache.go
@@ -0,0 +1,317 @@
+package cache
+
+import (
+ "database/sql"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+ _ "modernc.org/sqlite"
+)
+
+type Key struct {
+ Name string
+ Qtype uint16
+ Class uint16
+}
+
+type entry struct {
+ msg *dns.Msg
+ storedAt time.Time
+ ttl time.Duration
+}
+
+func (e *entry) expired() bool {
+ return time.Since(e.storedAt) >= e.ttl
+}
+
+func (e *entry) remaining() time.Duration {
+ d := e.ttl - time.Since(e.storedAt)
+ if d < 0 {
+ return 0
+ }
+ return d
+}
+
+type Cache struct {
+ mu sync.RWMutex
+ entries map[Key]*entry
+ maxSize int
+ db *sql.DB
+
+ hits int64
+ misses int64
+ evicted int64
+ stopCh chan struct{}
+}
+
+func NewCache(maxSize int, dbPath string) (*Cache, error) {
+ if maxSize <= 0 {
+ maxSize = 100000
+ }
+ c := &Cache{
+ entries: make(map[Key]*entry),
+ maxSize: maxSize,
+ stopCh: make(chan struct{}),
+ }
+
+ if dbPath != "" {
+ db, err := sql.Open("sqlite", dbPath)
+ if err != nil {
+ return nil, err
+ }
+ c.db = db
+
+ db.Exec("PRAGMA journal_mode=WAL")
+ db.Exec("PRAGMA synchronous=NORMAL")
+ db.Exec("PRAGMA cache_size=-65536")
+
+ if _, err := db.Exec(`
+ CREATE TABLE IF NOT EXISTS cache (
+ name TEXT NOT NULL,
+ qtype INTEGER NOT NULL,
+ class INTEGER NOT NULL,
+ data BLOB NOT NULL,
+ stored_at INTEGER NOT NULL,
+ ttl_ns INTEGER NOT NULL,
+ PRIMARY KEY (name, qtype, class)
+ )
+ `); err != nil {
+ db.Close()
+ return nil, err
+ }
+
+ c.loadFromDB()
+ }
+ go c.evictLoop()
+ return c, nil
+}
+
+func (c *Cache) Stop() {
+ close(c.stopCh)
+ if c.db != nil {
+ c.db.Close()
+ }
+}
+
+func (c *Cache) Get(key Key) (*dns.Msg, bool) {
+ c.mu.RLock()
+ e, ok := c.entries[key]
+ if !ok || e.expired() {
+ c.mu.RUnlock()
+ atomic.AddInt64(&c.misses, 1)
+ return nil, false
+ }
+ c.mu.RUnlock()
+
+ atomic.AddInt64(&c.hits, 1)
+ msg := e.msg.Copy()
+ remaining := e.remaining()
+ adjustTTL(msg, e.ttl, remaining)
+ return msg, true
+}
+
+func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
+ if ttl <= 0 {
+ ttl = computeTTL(msg)
+ }
+ if ttl <= 0 {
+ ttl = 60 * time.Second
+ }
+
+ e := &entry{
+ msg: msg.Copy(),
+ storedAt: time.Now(),
+ ttl: ttl,
+ }
+
+ c.mu.Lock()
+ if len(c.entries) >= c.maxSize {
+ c.evictLocked()
+ }
+ c.entries[key] = e
+ c.mu.Unlock()
+
+ if c.db != nil {
+ c.writeToDB(key, e)
+ }
+}
+
+func (c *Cache) Stats() (int64, int64, int64) {
+ return atomic.LoadInt64(&c.hits),
+ atomic.LoadInt64(&c.misses),
+ atomic.LoadInt64(&c.evicted)
+}
+
+func (c *Cache) Len() int {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return len(c.entries)
+}
+
+func (c *Cache) evictLocked() {
+ now := time.Now()
+ for k, e := range c.entries {
+ if now.After(e.storedAt.Add(e.ttl)) {
+ delete(c.entries, k)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+ }
+ if len(c.entries) < c.maxSize {
+ return
+ }
+
+ var oldestKey Key
+ var oldestTime time.Time
+ first := true
+ for k, e := range c.entries {
+ if first || e.storedAt.Before(oldestTime) {
+ oldestKey = k
+ oldestTime = e.storedAt
+ first = false
+ }
+ }
+ if !first {
+ delete(c.entries, oldestKey)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+}
+
+func (c *Cache) evictLoop() {
+ tk := time.NewTicker(1 * time.Minute)
+ defer tk.Stop()
+ for {
+ select {
+ case <-tk.C:
+ c.mu.Lock()
+ now := time.Now()
+ for k, e := range c.entries {
+ if now.After(e.storedAt.Add(e.ttl)) {
+ delete(c.entries, k)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+ }
+ c.mu.Unlock()
+ case <-c.stopCh:
+ return
+ }
+ }
+}
+func (c *Cache) writeToDB(key Key, e *entry) {
+ data, err := e.msg.Pack()
+ if err != nil {
+ return
+ }
+ c.db.Exec(
+ `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
+ VALUES (?, ?, ?, ?, ?, ?)`,
+ key.Name, key.Qtype, key.Class, data,
+ e.storedAt.UnixNano(), int64(e.ttl),
+ )
+}
+
+func (c *Cache) loadFromDB() {
+ rows, err := c.db.Query(
+ `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`,
+ )
+ if err != nil {
+ return
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var name string
+ var qtype, class uint16
+ var data []byte
+ var storedAtNano, ttlNs int64
+
+ if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil {
+ continue
+ }
+
+ msg := new(dns.Msg)
+ if err := msg.Unpack(data); err != nil {
+ continue
+ }
+
+ e := &entry{
+ msg: msg,
+ storedAt: time.Unix(0, storedAtNano),
+ ttl: time.Duration(ttlNs),
+ }
+
+ if !e.expired() {
+ c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e
+ } else {
+ c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class)
+ }
+ }
+}
+
+
+func computeTTL(msg *dns.Msg) time.Duration {
+ if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 {
+ for _, rr := range msg.Ns {
+ if soa, ok := rr.(*dns.SOA); ok {
+ return time.Duration(soa.Minttl) * time.Second
+ }
+ }
+ }
+
+ var min uint32
+ first := true
+ for _, rr := range msg.Answer {
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ for _, rr := range msg.Ns {
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ for _, rr := range msg.Extra {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ if !first {
+ return time.Duration(min) * time.Second
+ }
+ for _, rr := range msg.Ns {
+ if soa, ok := rr.(*dns.SOA); ok {
+ return time.Duration(soa.Minttl) * time.Second
+ }
+ }
+ return 0
+}
+
+func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
+ ratio := float64(remaining) / float64(originalTTL)
+ for _, rr := range msg.Answer {
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+ for _, rr := range msg.Ns {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+ for _, rr := range msg.Extra {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+}
+
+func applyRatio(ttl uint32, ratio float64) uint32 {
+ return uint32(float64(ttl) * ratio)
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
new file mode 100644
index 0000000..6556dcb
--- /dev/null
+++ b/internal/cache/cache_test.go
@@ -0,0 +1,140 @@
+package cache
+
+import (
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+func TestSetGet(t *testing.T) {
+ c, err := NewCache(100, "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Stop()
+
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300},
+ A: net.IPv4(1,2,3,4),
+ })
+
+ key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 300*time.Second)
+
+ got, ok := c.Get(key)
+ if !ok {
+ t.Fatal("expected cache hit")
+ }
+ if len(got.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(got.Answer))
+ }
+ a, _ := got.Answer[0].(*dns.A)
+ if !a.A.Equal(net.IPv4(1,2,3,4)) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ }
+}
+
+func TestExpiry(t *testing.T) {
+ c, _ := NewCache(10, "")
+ defer c.Stop()
+ msg := new(dns.Msg)
+ key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 30*time.Millisecond)
+ time.Sleep(60 * time.Millisecond)
+ _, ok := c.Get(key)
+ if ok {
+ t.Fatal("expected miss after expiry")
+ }
+}
+
+func TestEviction(t *testing.T) {
+ c, _ := NewCache(2, "")
+ defer c.Stop()
+ for i := 0; i < 5; i++ {
+ name := fmt.Sprintf("d%d.com.", i)
+ msg := new(dns.Msg)
+ c.Set(Key{Name: name, Qtype: dns.TypeA, Class: dns.ClassINET}, msg, 60*time.Second)
+ }
+ if c.Len() > 2 {
+ t.Errorf("expected ≤2 entries, got %d", c.Len())
+ }
+}
+
+func TestComputeTTL(t *testing.T) {
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120},
+ })
+ if d := computeTTL(msg); d != 120*time.Second {
+ t.Errorf("TTL = %v, want 120s", d)
+ }
+}
+
+func TestNegativeTTL(t *testing.T) {
+ msg := new(dns.Msg)
+ msg.Rcode = dns.RcodeNameError
+ msg.Ns = append(msg.Ns, &dns.SOA{
+ Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900},
+ Minttl: 300,
+ })
+ if d := computeTTL(msg); d != 300*time.Second {
+ t.Errorf("negative TTL = %v, want 300s", d)
+ }
+}
+
+func TestRace(t *testing.T) {
+ c, _ := NewCache(1000, "")
+ defer c.Stop()
+ done := make(chan struct{})
+ go func() {
+ for i := 0; i < 100; i++ {
+ msg := new(dns.Msg)
+ c.Set(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}, msg, time.Second)
+ }
+ close(done)
+ }()
+ for i := 0; i < 100; i++ {
+ c.Get(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET})
+ }
+ <-done
+ c.Stats()
+ c.Len()
+}
+
+func TestSQLitePersistence(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := dir + "/cache.db"
+
+ c, err := NewCache(100, dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300},
+ A: net.IPv4(1, 2, 3, 4),
+ })
+ key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 300*time.Second)
+ c.Stop()
+
+ c2, err := NewCache(100, dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Stop()
+
+ got, ok := c2.Get(key)
+ if !ok {
+ t.Fatal("expected cache hit from SQLite load")
+ }
+ a, _ := got.Answer[0].(*dns.A)
+ if !a.A.Equal(net.IPv4(1, 2, 3, 4)) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ }
+}