summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
committerradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
commit4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch)
tree12d9700e53775503ad7ba2beb72bedfc64bdd70d
parent3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff)
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
-rw-r--r--.gitignore1
-rw-r--r--go.sum29
-rw-r--r--internal/blocklist/blocklist.go181
-rw-r--r--internal/blocklist/blocklist_test.go78
-rw-r--r--internal/cache/cache.go317
-rw-r--r--internal/cache/cache_test.go140
-rw-r--r--internal/server/doh.go2
-rw-r--r--internal/server/handler.go81
-rw-r--r--internal/server/server.go9
-rw-r--r--internal/server/server_test.go6
-rw-r--r--main.go32
11 files changed, 849 insertions, 27 deletions
diff --git a/.gitignore b/.gitignore
index 0922409..514e3b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
+etc/blocklist/*.txt
sdns
todo.md
diff --git a/go.sum b/go.sum
index 92364b3..b600bb8 100644
--- a/go.sum
+++ b/go.sum
@@ -1,14 +1,43 @@
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
+github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
+github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
+golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
+golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
+golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
+golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
+modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
+modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo=
+modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go
new file mode 100644
index 0000000..ae6d20f
--- /dev/null
+++ b/internal/blocklist/blocklist.go
@@ -0,0 +1,181 @@
+package blocklist
+
+import (
+ "bufio"
+ "net/http"
+ "os"
+ "strings"
+ "sync"
+ "sync/atomic"
+)
+
+type ResponseAction int
+
+const (
+ ResponseZeroIP ResponseAction = iota
+ ResponseNXDOMAIN
+)
+
+type Blocklist struct {
+ mu sync.RWMutex
+ blocked *trie
+ exceptions *trie
+ response ResponseAction
+ TotalRules int32
+ Hits int64
+}
+
+func New(action ResponseAction) *Blocklist {
+ return &Blocklist{
+ blocked: newTrie(),
+ exceptions: newTrie(),
+ response: action,
+ }
+}
+
+func (b *Blocklist) Response() ResponseAction {
+ return b.response
+}
+
+func (b *Blocklist) IsBlocked(domain string) bool {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ labels := splitDomain(domain)
+ if !b.blocked.match(labels) {
+ return false
+ }
+ if b.exceptions.match(labels) {
+ return false
+ }
+ atomic.AddInt64(&b.Hits, 1)
+ return true
+}
+
+func (b *Blocklist) LoadFile(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ scanner := bufio.NewScanner(f)
+ var n int32
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || line[0] == '#' || line[0] == '!' {
+ continue
+ }
+ if b.addRule(line) {
+ n++
+ }
+ }
+ atomic.StoreInt32(&b.TotalRules, n)
+ return scanner.Err()
+}
+
+func (b *Blocklist) LoadURL(url string) error {
+ resp, err := http.Get(url)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ scanner := bufio.NewScanner(resp.Body)
+ var n int32
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || line[0] == '#' {
+ continue
+ }
+ if b.addRule(line) {
+ n++
+ }
+ }
+ atomic.AddInt32(&b.TotalRules, n)
+ return scanner.Err()
+}
+
+func (b *Blocklist) addRule(line string) bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if strings.HasPrefix(line, "@@") {
+ domain := strings.TrimPrefix(line, "@@")
+ domain = strings.TrimPrefix(domain, "||")
+ if idx := strings.Index(domain, "^"); idx > 0 {
+ domain = domain[:idx]
+ }
+ b.exceptions.insert(splitDomain(domain))
+ return true
+ }
+
+ fields := strings.Fields(line)
+ if len(fields) >= 2 {
+ ip := fields[0]
+ if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" {
+ b.blocked.insert(splitDomain(fields[len(fields)-1]))
+ return true
+ }
+ }
+ if strings.HasPrefix(line, "||") {
+ domain := strings.TrimPrefix(line, "||")
+ if idx := strings.Index(domain, "^"); idx > 0 {
+ domain = domain[:idx]
+ }
+ b.blocked.insert(splitDomain(domain))
+ return true
+ }
+ if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") {
+ b.blocked.insert(splitDomain(line))
+ return true
+ }
+ return false
+}
+
+func splitDomain(domain string) []string {
+ domain = strings.TrimSuffix(domain, ".")
+ return strings.Split(domain, ".")
+}
+
+type trieNode struct {
+ children map[string]*trieNode
+ terminal bool
+}
+
+type trie struct{ root *trieNode }
+
+func newTrie() *trie {
+ return &trie{
+ root: &trieNode{
+ children: make(map[string]*trieNode),
+ },
+ }
+}
+func (t *trie) insert(labels []string) {
+ node := t.root
+ for i := len(labels) - 1; i >= 0; i-- {
+ child, ok := node.children[labels[i]]
+ if !ok {
+ child = &trieNode{children: make(map[string]*trieNode)}
+ node.children[labels[i]] = child
+ }
+ node = child
+ }
+ node.terminal = true
+}
+
+func (t *trie) match(labels []string) bool {
+ node := t.root
+ for i := len(labels) - 1; i >= 0; i-- {
+ if node.terminal {
+ return true
+ }
+ child, ok := node.children[labels[i]]
+ if !ok {
+ return false
+ }
+ node = child
+ }
+ return node.terminal
+}
diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go
new file mode 100644
index 0000000..9ae1749
--- /dev/null
+++ b/internal/blocklist/blocklist_test.go
@@ -0,0 +1,78 @@
+package blocklist
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestBlockZeroIP(t *testing.T) {
+ b := New(ResponseZeroIP)
+ b.addRule("||example.com^")
+ b.addRule("0.0.0.0 doubleclick.net")
+
+ cases := []struct {
+ domain string
+ block bool
+ }{
+ {"example.com.", true},
+ {"sub.example.com.", true},
+ {"notexample.com.", false},
+ {"doubleclick.net.", true},
+ {"ads.doubleclick.net.", true},
+ {"google.com.", false},
+ }
+ for _, c := range cases {
+ got := b.IsBlocked(c.domain)
+ if got != c.block {
+ t.Errorf("IsBlocked(%q) = %v, want %v", c.domain, got, c.block)
+ }
+ }
+}
+
+func TestException(t *testing.T) {
+ b := New(ResponseZeroIP)
+ b.addRule("||example.com^")
+ b.addRule("@@||whitelist.example.com^")
+
+ if !b.IsBlocked("example.com.") {
+ t.Error("expected blocked")
+ }
+ if b.IsBlocked("whitelist.example.com.") {
+ t.Error("expected NOT blocked (exception)")
+ }
+ if b.IsBlocked("sub.whitelist.example.com.") {
+ t.Error("expected NOT blocked (exception subdomain)")
+ }
+}
+
+func TestLoadFile(t *testing.T) {
+ dir := t.TempDir()
+ p := filepath.Join(dir, "block.txt")
+ os.WriteFile(p, []byte("0.0.0.0 ads.com\n||tracker.net^\n"), 0644)
+
+ b := New(ResponseZeroIP)
+ if err := b.LoadFile(p); err != nil {
+ t.Fatal(err)
+ }
+ if !b.IsBlocked("ads.com.") {
+ t.Error("expected blocked")
+ }
+ if !b.IsBlocked("sub.tracker.net.") {
+ t.Error("expected blocked")
+ }
+ if b.IsBlocked("safe.org.") {
+ t.Error("expected NOT blocked")
+ }
+}
+
+func TestResponseType(t *testing.T) {
+ b := New(ResponseNXDOMAIN)
+ if b.Response() != ResponseNXDOMAIN {
+ t.Error("expected NXDOMAIN")
+ }
+ b2 := New(ResponseZeroIP)
+ if b2.Response() != ResponseZeroIP {
+ t.Error("expected ZeroIP")
+ }
+}
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
new file mode 100644
index 0000000..a2d86a0
--- /dev/null
+++ b/internal/cache/cache.go
@@ -0,0 +1,317 @@
+package cache
+
+import (
+ "database/sql"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/dns"
+ _ "modernc.org/sqlite"
+)
+
+type Key struct {
+ Name string
+ Qtype uint16
+ Class uint16
+}
+
+type entry struct {
+ msg *dns.Msg
+ storedAt time.Time
+ ttl time.Duration
+}
+
+func (e *entry) expired() bool {
+ return time.Since(e.storedAt) >= e.ttl
+}
+
+func (e *entry) remaining() time.Duration {
+ d := e.ttl - time.Since(e.storedAt)
+ if d < 0 {
+ return 0
+ }
+ return d
+}
+
+type Cache struct {
+ mu sync.RWMutex
+ entries map[Key]*entry
+ maxSize int
+ db *sql.DB
+
+ hits int64
+ misses int64
+ evicted int64
+ stopCh chan struct{}
+}
+
+func NewCache(maxSize int, dbPath string) (*Cache, error) {
+ if maxSize <= 0 {
+ maxSize = 100000
+ }
+ c := &Cache{
+ entries: make(map[Key]*entry),
+ maxSize: maxSize,
+ stopCh: make(chan struct{}),
+ }
+
+ if dbPath != "" {
+ db, err := sql.Open("sqlite", dbPath)
+ if err != nil {
+ return nil, err
+ }
+ c.db = db
+
+ db.Exec("PRAGMA journal_mode=WAL")
+ db.Exec("PRAGMA synchronous=NORMAL")
+ db.Exec("PRAGMA cache_size=-65536")
+
+ if _, err := db.Exec(`
+ CREATE TABLE IF NOT EXISTS cache (
+ name TEXT NOT NULL,
+ qtype INTEGER NOT NULL,
+ class INTEGER NOT NULL,
+ data BLOB NOT NULL,
+ stored_at INTEGER NOT NULL,
+ ttl_ns INTEGER NOT NULL,
+ PRIMARY KEY (name, qtype, class)
+ )
+ `); err != nil {
+ db.Close()
+ return nil, err
+ }
+
+ c.loadFromDB()
+ }
+ go c.evictLoop()
+ return c, nil
+}
+
+func (c *Cache) Stop() {
+ close(c.stopCh)
+ if c.db != nil {
+ c.db.Close()
+ }
+}
+
+func (c *Cache) Get(key Key) (*dns.Msg, bool) {
+ c.mu.RLock()
+ e, ok := c.entries[key]
+ if !ok || e.expired() {
+ c.mu.RUnlock()
+ atomic.AddInt64(&c.misses, 1)
+ return nil, false
+ }
+ c.mu.RUnlock()
+
+ atomic.AddInt64(&c.hits, 1)
+ msg := e.msg.Copy()
+ remaining := e.remaining()
+ adjustTTL(msg, e.ttl, remaining)
+ return msg, true
+}
+
+func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) {
+ if ttl <= 0 {
+ ttl = computeTTL(msg)
+ }
+ if ttl <= 0 {
+ ttl = 60 * time.Second
+ }
+
+ e := &entry{
+ msg: msg.Copy(),
+ storedAt: time.Now(),
+ ttl: ttl,
+ }
+
+ c.mu.Lock()
+ if len(c.entries) >= c.maxSize {
+ c.evictLocked()
+ }
+ c.entries[key] = e
+ c.mu.Unlock()
+
+ if c.db != nil {
+ c.writeToDB(key, e)
+ }
+}
+
+func (c *Cache) Stats() (int64, int64, int64) {
+ return atomic.LoadInt64(&c.hits),
+ atomic.LoadInt64(&c.misses),
+ atomic.LoadInt64(&c.evicted)
+}
+
+func (c *Cache) Len() int {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return len(c.entries)
+}
+
+func (c *Cache) evictLocked() {
+ now := time.Now()
+ for k, e := range c.entries {
+ if now.After(e.storedAt.Add(e.ttl)) {
+ delete(c.entries, k)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+ }
+ if len(c.entries) < c.maxSize {
+ return
+ }
+
+ var oldestKey Key
+ var oldestTime time.Time
+ first := true
+ for k, e := range c.entries {
+ if first || e.storedAt.Before(oldestTime) {
+ oldestKey = k
+ oldestTime = e.storedAt
+ first = false
+ }
+ }
+ if !first {
+ delete(c.entries, oldestKey)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+}
+
+func (c *Cache) evictLoop() {
+ tk := time.NewTicker(1 * time.Minute)
+ defer tk.Stop()
+ for {
+ select {
+ case <-tk.C:
+ c.mu.Lock()
+ now := time.Now()
+ for k, e := range c.entries {
+ if now.After(e.storedAt.Add(e.ttl)) {
+ delete(c.entries, k)
+ atomic.AddInt64(&c.evicted, 1)
+ }
+ }
+ c.mu.Unlock()
+ case <-c.stopCh:
+ return
+ }
+ }
+}
+func (c *Cache) writeToDB(key Key, e *entry) {
+ data, err := e.msg.Pack()
+ if err != nil {
+ return
+ }
+ c.db.Exec(
+ `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns)
+ VALUES (?, ?, ?, ?, ?, ?)`,
+ key.Name, key.Qtype, key.Class, data,
+ e.storedAt.UnixNano(), int64(e.ttl),
+ )
+}
+
+func (c *Cache) loadFromDB() {
+ rows, err := c.db.Query(
+ `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`,
+ )
+ if err != nil {
+ return
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var name string
+ var qtype, class uint16
+ var data []byte
+ var storedAtNano, ttlNs int64
+
+ if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil {
+ continue
+ }
+
+ msg := new(dns.Msg)
+ if err := msg.Unpack(data); err != nil {
+ continue
+ }
+
+ e := &entry{
+ msg: msg,
+ storedAt: time.Unix(0, storedAtNano),
+ ttl: time.Duration(ttlNs),
+ }
+
+ if !e.expired() {
+ c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e
+ } else {
+ c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class)
+ }
+ }
+}
+
+
+func computeTTL(msg *dns.Msg) time.Duration {
+ if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 {
+ for _, rr := range msg.Ns {
+ if soa, ok := rr.(*dns.SOA); ok {
+ return time.Duration(soa.Minttl) * time.Second
+ }
+ }
+ }
+
+ var min uint32
+ first := true
+ for _, rr := range msg.Answer {
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ for _, rr := range msg.Ns {
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ for _, rr := range msg.Extra {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ if first || rr.Header().Ttl < min {
+ min = rr.Header().Ttl
+ first = false
+ }
+ }
+ if !first {
+ return time.Duration(min) * time.Second
+ }
+ for _, rr := range msg.Ns {
+ if soa, ok := rr.(*dns.SOA); ok {
+ return time.Duration(soa.Minttl) * time.Second
+ }
+ }
+ return 0
+}
+
+func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) {
+ ratio := float64(remaining) / float64(originalTTL)
+ for _, rr := range msg.Answer {
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+ for _, rr := range msg.Ns {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+ for _, rr := range msg.Extra {
+ if rr.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio)
+ }
+}
+
+func applyRatio(ttl uint32, ratio float64) uint32 {
+ return uint32(float64(ttl) * ratio)
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
new file mode 100644
index 0000000..6556dcb
--- /dev/null
+++ b/internal/cache/cache_test.go
@@ -0,0 +1,140 @@
+package cache
+
+import (
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+func TestSetGet(t *testing.T) {
+ c, err := NewCache(100, "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Stop()
+
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300},
+ A: net.IPv4(1,2,3,4),
+ })
+
+ key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 300*time.Second)
+
+ got, ok := c.Get(key)
+ if !ok {
+ t.Fatal("expected cache hit")
+ }
+ if len(got.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(got.Answer))
+ }
+ a, _ := got.Answer[0].(*dns.A)
+ if !a.A.Equal(net.IPv4(1,2,3,4)) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ }
+}
+
+func TestExpiry(t *testing.T) {
+ c, _ := NewCache(10, "")
+ defer c.Stop()
+ msg := new(dns.Msg)
+ key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 30*time.Millisecond)
+ time.Sleep(60 * time.Millisecond)
+ _, ok := c.Get(key)
+ if ok {
+ t.Fatal("expected miss after expiry")
+ }
+}
+
+func TestEviction(t *testing.T) {
+ c, _ := NewCache(2, "")
+ defer c.Stop()
+ for i := 0; i < 5; i++ {
+ name := fmt.Sprintf("d%d.com.", i)
+ msg := new(dns.Msg)
+ c.Set(Key{Name: name, Qtype: dns.TypeA, Class: dns.ClassINET}, msg, 60*time.Second)
+ }
+ if c.Len() > 2 {
+ t.Errorf("expected ≤2 entries, got %d", c.Len())
+ }
+}
+
+func TestComputeTTL(t *testing.T) {
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120},
+ })
+ if d := computeTTL(msg); d != 120*time.Second {
+ t.Errorf("TTL = %v, want 120s", d)
+ }
+}
+
+func TestNegativeTTL(t *testing.T) {
+ msg := new(dns.Msg)
+ msg.Rcode = dns.RcodeNameError
+ msg.Ns = append(msg.Ns, &dns.SOA{
+ Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900},
+ Minttl: 300,
+ })
+ if d := computeTTL(msg); d != 300*time.Second {
+ t.Errorf("negative TTL = %v, want 300s", d)
+ }
+}
+
+func TestRace(t *testing.T) {
+ c, _ := NewCache(1000, "")
+ defer c.Stop()
+ done := make(chan struct{})
+ go func() {
+ for i := 0; i < 100; i++ {
+ msg := new(dns.Msg)
+ c.Set(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}, msg, time.Second)
+ }
+ close(done)
+ }()
+ for i := 0; i < 100; i++ {
+ c.Get(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET})
+ }
+ <-done
+ c.Stats()
+ c.Len()
+}
+
+func TestSQLitePersistence(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := dir + "/cache.db"
+
+ c, err := NewCache(100, dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ msg := new(dns.Msg)
+ msg.Answer = append(msg.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300},
+ A: net.IPv4(1, 2, 3, 4),
+ })
+ key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET}
+ c.Set(key, msg, 300*time.Second)
+ c.Stop()
+
+ c2, err := NewCache(100, dbPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Stop()
+
+ got, ok := c2.Get(key)
+ if !ok {
+ t.Fatal("expected cache hit from SQLite load")
+ }
+ a, _ := got.Answer[0].(*dns.A)
+ if !a.A.Equal(net.IPv4(1, 2, 3, 4)) {
+ t.Errorf("IP = %s, want 1.2.3.4", a.A)
+ }
+}
diff --git a/internal/server/doh.go b/internal/server/doh.go
index 3f5a538..b46736e 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
return
}
- resp := s.buildResponse(msg)
+ resp, _ := s.buildResponse(msg)
packed, err := resp.Pack()
if err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
diff --git a/internal/server/handler.go b/internal/server/handler.go
index bac1c81..4aa771f 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -2,9 +2,13 @@ package server
import (
"context"
- "github.com/miekg/dns"
"log/slog"
+ "net"
"time"
+
+ "github.com/miekg/dns"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
)
func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
@@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
return
}
- resp := s.buildResponse(req)
+ resp, blocked := s.buildResponse(req)
if err := w.WriteMsg(resp); err != nil {
slog.Error("write response failed",
- "err", err,
+ "err", err,
"qname", req.Question[0].Name,
"qtype", dns.TypeToString[req.Question[0].Qtype],
)
return
}
slog.Info("query served",
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
- "rcode", dns.RcodeToString[resp.Rcode],
- "client", w.RemoteAddr().String(),
+ "qname", req.Question[0].Name,
+ "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ "client", w.RemoteAddr().String(),
+ "blocked", blocked,
)
}
-func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
+func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
- m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- return m
+ return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false
}
q := req.Question[0]
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Authoritative = false
- if opt := req.IsEdns0(); opt != nil {
- resp.SetEdns0(4096, false)
+ if s.cache != nil {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ if cached, ok := s.cache.Get(key); ok {
+ return cached, false
+ }
+ }
+
+ if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) {
+ return s.blockedResponse(req), true
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", q.Name,
"qtype", dns.TypeToString[q.Qtype],
)
- resp.Rcode = dns.RcodeServerFailure
- return resp
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeServerFailure)
+ return m, false
+ }
+
+ reply.Id = req.Id
+
+ if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ s.cache.Set(key, reply, 0)
+ }
+
+ return reply, false
+}
+
+func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ m.Authoritative = true
+
+ if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
+ m.Rcode = dns.RcodeNameError
+ return m
}
- return reply
+ q := req.Question[0]
+ switch q.Qtype {
+ case dns.TypeA:
+ m.Answer = append(m.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
+ A: net.IPv4(0, 0, 0, 0),
+ })
+ case dns.TypeAAAA:
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
+ AAAA: net.IPv6zero,
+ })
+ }
+ return m
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 3114073..e0490bd 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -8,18 +8,23 @@ import (
"github.com/miekg/dns"
"sdns/internal/resolver"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
)
type Server struct {
logger *slog.Logger
resolver *resolver.Resolver
+ cache *cache.Cache
+ blocklist *blocklist.Blocklist
udp *dns.Server
tcp *dns.Server
doh *http.Server
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) {
- s := &Server{logger: logger, resolver: r}
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
+r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
+ s := &Server{logger: logger, resolver: r, cache: c, blocklist: b}
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 6fc2092..c49d5f3 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -40,7 +40,7 @@ func TestBuildResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- resp := s.buildResponse(tt.req)
+ resp, _ := s.buildResponse(tt.req)
if resp.Rcode != tt.wantRcode {
t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode)
}
@@ -63,7 +63,7 @@ func TestBuildResponseWithQuery(t *testing.T) {
// Valid query → must not panic, rcode must be valid
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
- resp := s.buildResponse(m)
+ resp, _ := s.buildResponse(m)
if resp == nil {
t.Fatal("buildResponse returned nil")
}
@@ -100,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) {
if err := msg.Unpack(data); err != nil {
return
}
- resp := s.buildResponse(msg)
+ resp, _ := s.buildResponse(msg)
if resp == nil {
t.Fatal("buildResponse returned nil")
}
diff --git a/main.go b/main.go
index 2ddc86f..96c51cb 100644
--- a/main.go
+++ b/main.go
@@ -5,8 +5,11 @@ import (
"log/slog"
"os"
"os/signal"
+ "path/filepath"
"syscall"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
"sdns/internal/resolver"
"sdns/internal/server"
)
@@ -19,6 +22,33 @@ func main() {
r := resolver.New()
+ dbPath := os.Getenv("SDNS_CACHE_DB")
+ if dbPath != "" {
+ logger.Info("cache using sqlite", "path", dbPath)
+ } else {
+ logger.Info("cache using in-memory")
+ }
+ c, err := cache.NewCache(100000, dbPath)
+ if err != nil {
+ logger.Error("create cache failed", "err", err)
+ os.Exit(1)
+ }
+ defer c.Stop()
+
+ var bl *blocklist.Blocklist
+ matches, _ := filepath.Glob("etc/blocklist/*.txt")
+ if len(matches) > 0 {
+ bl = blocklist.New(blocklist.ResponseZeroIP)
+ for _, f := range matches {
+ if err := bl.LoadFile(f); err != nil {
+ logger.Error("load blocklist failed", "file", f, "err", err)
+ os.Exit(1)
+ }
+ logger.Info("blocklist loaded", "file", f, "rules", bl.TotalRules)
+ }
+ } else {
+ logger.Info("no blocklist files in etc/blocklist/, ad-blocking disabled")
+ }
udp := os.Getenv("SDNS_LISTEN_UDP")
if udp == "" {
udp = ":5353"
@@ -37,7 +67,7 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
- srv, err := server.New(udp, tcp, doh, logger, r)
+ srv, err := server.New(udp, tcp, doh, logger, r, c, bl)
if err != nil {
logger.Error("create server failed", "err", err)
os.Exit(1)