diff options
| author | radhitya <alif@radhitya.org> | 2026-06-21 09:48:42 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-21 09:48:42 +0700 |
| commit | b7359e1d45f505171356bcae3c7d5e2341ecc859 (patch) | |
| tree | f91d4a4b08ce279d488a76e9b7141e69fc844ea9 /internal/cache | |
| parent | 2b1f613c42de3861141eb6f93c1740b6937ee183 (diff) | |
forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply
Diffstat (limited to 'internal/cache')
| -rw-r--r-- | internal/cache/cache.go | 159 | ||||
| -rw-r--r-- | internal/cache/cache_test.go | 20 |
2 files changed, 46 insertions, 133 deletions
diff --git a/internal/cache/cache.go b/internal/cache/cache.go index bb35b8e..b0a0959 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -12,18 +12,18 @@ import ( ) type Key struct { - Name string + Name string Qtype uint16 Class uint16 } type entry struct { - msg *dns.Msg + packed []byte storedAt time.Time - ttl time.Duration + ttl time.Duration } -func (e *entry) expired() bool { +func (e *entry) expired() bool { return time.Since(e.storedAt) >= e.ttl } @@ -36,20 +36,20 @@ func (e *entry) remaining() time.Duration { } type Cache struct { - mu sync.RWMutex + mu sync.RWMutex entries map[Key]*entry maxSize int - db *sql.DB - dbCh chan dbWrite - wg sync.WaitGroup - hits int64 - misses int64 + db *sql.DB + dbCh chan dbWrite + wg sync.WaitGroup + hits int64 + misses int64 evicted int64 - stopCh chan struct{} + stopCh chan struct{} } type dbWrite struct { key Key - e *entry + e *entry } func NewCache(maxSize int, dbPath string) (*Cache, error) { @@ -59,7 +59,7 @@ func NewCache(maxSize int, dbPath string) (*Cache, error) { c := &Cache{ entries: make(map[Key]*entry), maxSize: maxSize, - stopCh: make(chan struct{}), + stopCh: make(chan struct{}), } if dbPath != "" { @@ -108,7 +108,7 @@ func (c *Cache) Stop() { } } -func (c *Cache) Get(key Key) (*dns.Msg, bool) { +func (c *Cache) Get(key Key) ([]byte, bool) { c.mu.RLock() e, ok := c.entries[key] if !ok || e.expired() { @@ -119,10 +119,11 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) { c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) - msg := deepCopyMsg(e.msg) - remaining := e.remaining() - adjustTTL(msg, e.ttl, remaining) - return msg, true + cp := make([]byte, len(e.packed)) + copy(cp, e.packed) + cp[0] = 0 + cp[1] = 0 + return cp, true } func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { @@ -133,10 +134,16 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { ttl = 60 * time.Second } + if err := msg.Pack(); err != nil { + return + } + cp := make([]byte, len(msg.Data)) + copy(cp, msg.Data) + e := &entry{ - msg: deepCopyMsg(msg), + packed: cp, storedAt: time.Now(), - ttl: ttl, + ttl: ttl, } c.mu.Lock() @@ -156,8 +163,8 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { func (c *Cache) Stats() (int64, int64, int64) { return atomic.LoadInt64(&c.hits), - atomic.LoadInt64(&c.misses), - atomic.LoadInt64(&c.evicted) + atomic.LoadInt64(&c.misses), + atomic.LoadInt64(&c.evicted) } func (c *Cache) Len() int { @@ -228,15 +235,9 @@ func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } - err := e.msg.Pack() - if err != nil { - return - } - data := e.msg.Data - _, err = c.db.Exec( + _, err := c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) - VALUES (?, ?, ?, ?, ?, ?)`, - key.Name, key.Qtype, key.Class, data, + VALUES (?,?,?,?,?,?)`, key.Name, key.Qtype, key.Class, e.packed, e.storedAt.UnixNano(), int64(e.ttl), ) if err != nil { @@ -269,14 +270,8 @@ func (c *Cache) loadFromDB() { continue } - msg := new(dns.Msg) - msg.Data = data - if err := msg.Unpack(); err != nil { - continue - } - e := &entry{ - msg: msg, + packed: data, storedAt: time.Unix(0, storedAtNano), ttl: time.Duration(ttlNs), } @@ -289,7 +284,6 @@ func (c *Cache) loadFromDB() { } } - func computeTTL(msg *dns.Msg) time.Duration { if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { for _, rr := range msg.Ns { @@ -332,92 +326,3 @@ func computeTTL(msg *dns.Msg) time.Duration { } return 0 } - -func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { - ratio := float64(remaining) / float64(originalTTL) - for _, rr := range msg.Answer { - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } - for _, rr := range msg.Ns { - if dns.RRToType(rr) == dns.TypeOPT { - continue - } - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } - for _, rr := range msg.Extra { - if dns.RRToType(rr) == dns.TypeOPT { - continue - } - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } -} - -func applyRatio(ttl uint32, ratio float64) uint32 { - return uint32(float64(ttl) * ratio) -} - -func deepCopyMsg(msg *dns.Msg) *dns.Msg { - cp := new(dns.Msg) - cp.Response = msg.Response - cp.ID = msg.ID - cp.Opcode = msg.Opcode - cp.Authoritative = msg.Authoritative - cp.Truncated = msg.Truncated - cp.RecursionDesired = msg.RecursionDesired - cp.RecursionAvailable = msg.RecursionAvailable - cp.Rcode = msg.Rcode - cp.UDPSize = msg.UDPSize - cp.Question = copyRRSlice(msg.Question) - cp.Answer = copyRRSlice(msg.Answer) - cp.Ns = copyRRSlice(msg.Ns) - cp.Extra = copyRRSlice(msg.Extra) - return cp -} - -func copyRRSlice(src []dns.RR) []dns.RR { - if src == nil { - return nil - } - dst := make([]dns.RR, len(src)) - for i, rr := range src { - dst[i] = copyRR(rr) - } - return dst -} - -func copyRR(rr dns.RR) dns.RR { - if rr == nil { - return nil - } - switch v := rr.(type) { - case *dns.A: - cp := *v - return &cp - case *dns.AAAA: - cp := *v - return &cp - case *dns.NS: - cp := *v - return &cp - case *dns.CNAME: - cp := *v - return &cp - case *dns.SOA: - cp := *v - return &cp - case *dns.MX: - cp := *v - return &cp - case *dns.TXT: - cp := *v - return &cp - case *dns.SRV: - cp := *v - return &cp - case *dns.PTR: - cp := *v - return &cp - default: - return rr - } -} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 33e1d30..aa0aeb0 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -26,14 +26,18 @@ func TestSetGet(t *testing.T) { key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} c.Set(key, msg, 300*time.Second) - got, ok := c.Get(key) + packed, ok := c.Get(key) if !ok { t.Fatal("expected cache hit") } - if len(got.Answer) != 1 { - t.Fatalf("expected 1 answer, got %d", len(got.Answer)) + msg.Data = packed + if err := msg.Unpack(); err != nil { + t.Fatal(err) + } + if len(msg.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(msg.Answer)) } - a, _ := got.Answer[0].(*dns.A) + a, _ := msg.Answer[0].(*dns.A) if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } @@ -130,11 +134,15 @@ func TestSQLitePersistence(t *testing.T) { } defer c2.Stop() - got, ok := c2.Get(key) + packed, ok := c2.Get(key) if !ok { t.Fatal("expected cache hit from SQLite load") } - a, _ := got.Answer[0].(*dns.A) + msg.Data = packed + if err := msg.Unpack(); err != nil { + t.Fatal(err) + } + a, _ := msg.Answer[0].(*dns.A) if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } |
