diff options
| author | radhitya <alif@radhitya.org> | 2026-06-14 14:36:32 +0700 |
|---|---|---|
| committer | radhitya <alif@radhitya.org> | 2026-06-14 14:36:32 +0700 |
| commit | 4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch) | |
| tree | 12d9700e53775503ad7ba2beb72bedfc64bdd70d | |
| parent | 3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff) | |
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | go.sum | 29 | ||||
| -rw-r--r-- | internal/blocklist/blocklist.go | 181 | ||||
| -rw-r--r-- | internal/blocklist/blocklist_test.go | 78 | ||||
| -rw-r--r-- | internal/cache/cache.go | 317 | ||||
| -rw-r--r-- | internal/cache/cache_test.go | 140 | ||||
| -rw-r--r-- | internal/server/doh.go | 2 | ||||
| -rw-r--r-- | internal/server/handler.go | 81 | ||||
| -rw-r--r-- | internal/server/server.go | 9 | ||||
| -rw-r--r-- | internal/server/server_test.go | 6 | ||||
| -rw-r--r-- | main.go | 32 |
11 files changed, 849 insertions, 27 deletions
@@ -1,2 +1,3 @@ +etc/blocklist/*.txt sdns todo.md @@ -1,14 +1,43 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU= +modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.52.0 h1:p4dhYh2tXZCiyaqHwRVJDjIGKWyXayiQpThxgDzJaxo= +modernc.org/sqlite v1.52.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM= diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go new file mode 100644 index 0000000..ae6d20f --- /dev/null +++ b/internal/blocklist/blocklist.go @@ -0,0 +1,181 @@ +package blocklist + +import ( + "bufio" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" +) + +type ResponseAction int + +const ( + ResponseZeroIP ResponseAction = iota + ResponseNXDOMAIN +) + +type Blocklist struct { + mu sync.RWMutex + blocked *trie + exceptions *trie + response ResponseAction + TotalRules int32 + Hits int64 +} + +func New(action ResponseAction) *Blocklist { + return &Blocklist{ + blocked: newTrie(), + exceptions: newTrie(), + response: action, + } +} + +func (b *Blocklist) Response() ResponseAction { + return b.response +} + +func (b *Blocklist) IsBlocked(domain string) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + labels := splitDomain(domain) + if !b.blocked.match(labels) { + return false + } + if b.exceptions.match(labels) { + return false + } + atomic.AddInt64(&b.Hits, 1) + return true +} + +func (b *Blocklist) LoadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' || line[0] == '!' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.StoreInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) LoadURL(url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.AddInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) addRule(line string) bool { + b.mu.Lock() + defer b.mu.Unlock() + + if strings.HasPrefix(line, "@@") { + domain := strings.TrimPrefix(line, "@@") + domain = strings.TrimPrefix(domain, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.exceptions.insert(splitDomain(domain)) + return true + } + + fields := strings.Fields(line) + if len(fields) >= 2 { + ip := fields[0] + if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" { + b.blocked.insert(splitDomain(fields[len(fields)-1])) + return true + } + } + if strings.HasPrefix(line, "||") { + domain := strings.TrimPrefix(line, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.blocked.insert(splitDomain(domain)) + return true + } + if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") { + b.blocked.insert(splitDomain(line)) + return true + } + return false +} + +func splitDomain(domain string) []string { + domain = strings.TrimSuffix(domain, ".") + return strings.Split(domain, ".") +} + +type trieNode struct { + children map[string]*trieNode + terminal bool +} + +type trie struct{ root *trieNode } + +func newTrie() *trie { + return &trie{ + root: &trieNode{ + children: make(map[string]*trieNode), + }, + } +} +func (t *trie) insert(labels []string) { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + child, ok := node.children[labels[i]] + if !ok { + child = &trieNode{children: make(map[string]*trieNode)} + node.children[labels[i]] = child + } + node = child + } + node.terminal = true +} + +func (t *trie) match(labels []string) bool { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + if node.terminal { + return true + } + child, ok := node.children[labels[i]] + if !ok { + return false + } + node = child + } + return node.terminal +} diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go new file mode 100644 index 0000000..9ae1749 --- /dev/null +++ b/internal/blocklist/blocklist_test.go @@ -0,0 +1,78 @@ +package blocklist + +import ( + "os" + "path/filepath" + "testing" +) + +func TestBlockZeroIP(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("0.0.0.0 doubleclick.net") + + cases := []struct { + domain string + block bool + }{ + {"example.com.", true}, + {"sub.example.com.", true}, + {"notexample.com.", false}, + {"doubleclick.net.", true}, + {"ads.doubleclick.net.", true}, + {"google.com.", false}, + } + for _, c := range cases { + got := b.IsBlocked(c.domain) + if got != c.block { + t.Errorf("IsBlocked(%q) = %v, want %v", c.domain, got, c.block) + } + } +} + +func TestException(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("@@||whitelist.example.com^") + + if !b.IsBlocked("example.com.") { + t.Error("expected blocked") + } + if b.IsBlocked("whitelist.example.com.") { + t.Error("expected NOT blocked (exception)") + } + if b.IsBlocked("sub.whitelist.example.com.") { + t.Error("expected NOT blocked (exception subdomain)") + } +} + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "block.txt") + os.WriteFile(p, []byte("0.0.0.0 ads.com\n||tracker.net^\n"), 0644) + + b := New(ResponseZeroIP) + if err := b.LoadFile(p); err != nil { + t.Fatal(err) + } + if !b.IsBlocked("ads.com.") { + t.Error("expected blocked") + } + if !b.IsBlocked("sub.tracker.net.") { + t.Error("expected blocked") + } + if b.IsBlocked("safe.org.") { + t.Error("expected NOT blocked") + } +} + +func TestResponseType(t *testing.T) { + b := New(ResponseNXDOMAIN) + if b.Response() != ResponseNXDOMAIN { + t.Error("expected NXDOMAIN") + } + b2 := New(ResponseZeroIP) + if b2.Response() != ResponseZeroIP { + t.Error("expected ZeroIP") + } +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a2d86a0 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,317 @@ +package cache + +import ( + "database/sql" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + _ "modernc.org/sqlite" +) + +type Key struct { + Name string + Qtype uint16 + Class uint16 +} + +type entry struct { + msg *dns.Msg + storedAt time.Time + ttl time.Duration +} + +func (e *entry) expired() bool { + return time.Since(e.storedAt) >= e.ttl +} + +func (e *entry) remaining() time.Duration { + d := e.ttl - time.Since(e.storedAt) + if d < 0 { + return 0 + } + return d +} + +type Cache struct { + mu sync.RWMutex + entries map[Key]*entry + maxSize int + db *sql.DB + + hits int64 + misses int64 + evicted int64 + stopCh chan struct{} +} + +func NewCache(maxSize int, dbPath string) (*Cache, error) { + if maxSize <= 0 { + maxSize = 100000 + } + c := &Cache{ + entries: make(map[Key]*entry), + maxSize: maxSize, + stopCh: make(chan struct{}), + } + + if dbPath != "" { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + c.db = db + + db.Exec("PRAGMA journal_mode=WAL") + db.Exec("PRAGMA synchronous=NORMAL") + db.Exec("PRAGMA cache_size=-65536") + + if _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS cache ( + name TEXT NOT NULL, + qtype INTEGER NOT NULL, + class INTEGER NOT NULL, + data BLOB NOT NULL, + stored_at INTEGER NOT NULL, + ttl_ns INTEGER NOT NULL, + PRIMARY KEY (name, qtype, class) + ) + `); err != nil { + db.Close() + return nil, err + } + + c.loadFromDB() + } + go c.evictLoop() + return c, nil +} + +func (c *Cache) Stop() { + close(c.stopCh) + if c.db != nil { + c.db.Close() + } +} + +func (c *Cache) Get(key Key) (*dns.Msg, bool) { + c.mu.RLock() + e, ok := c.entries[key] + if !ok || e.expired() { + c.mu.RUnlock() + atomic.AddInt64(&c.misses, 1) + return nil, false + } + c.mu.RUnlock() + + atomic.AddInt64(&c.hits, 1) + msg := e.msg.Copy() + remaining := e.remaining() + adjustTTL(msg, e.ttl, remaining) + return msg, true +} + +func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { + if ttl <= 0 { + ttl = computeTTL(msg) + } + if ttl <= 0 { + ttl = 60 * time.Second + } + + e := &entry{ + msg: msg.Copy(), + storedAt: time.Now(), + ttl: ttl, + } + + c.mu.Lock() + if len(c.entries) >= c.maxSize { + c.evictLocked() + } + c.entries[key] = e + c.mu.Unlock() + + if c.db != nil { + c.writeToDB(key, e) + } +} + +func (c *Cache) Stats() (int64, int64, int64) { + return atomic.LoadInt64(&c.hits), + atomic.LoadInt64(&c.misses), + atomic.LoadInt64(&c.evicted) +} + +func (c *Cache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.entries) +} + +func (c *Cache) evictLocked() { + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + if len(c.entries) < c.maxSize { + return + } + + var oldestKey Key + var oldestTime time.Time + first := true + for k, e := range c.entries { + if first || e.storedAt.Before(oldestTime) { + oldestKey = k + oldestTime = e.storedAt + first = false + } + } + if !first { + delete(c.entries, oldestKey) + atomic.AddInt64(&c.evicted, 1) + } +} + +func (c *Cache) evictLoop() { + tk := time.NewTicker(1 * time.Minute) + defer tk.Stop() + for { + select { + case <-tk.C: + c.mu.Lock() + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + c.mu.Unlock() + case <-c.stopCh: + return + } + } +} +func (c *Cache) writeToDB(key Key, e *entry) { + data, err := e.msg.Pack() + if err != nil { + return + } + c.db.Exec( + `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) + VALUES (?, ?, ?, ?, ?, ?)`, + key.Name, key.Qtype, key.Class, data, + e.storedAt.UnixNano(), int64(e.ttl), + ) +} + +func (c *Cache) loadFromDB() { + rows, err := c.db.Query( + `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`, + ) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var name string + var qtype, class uint16 + var data []byte + var storedAtNano, ttlNs int64 + + if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil { + continue + } + + msg := new(dns.Msg) + if err := msg.Unpack(data); err != nil { + continue + } + + e := &entry{ + msg: msg, + storedAt: time.Unix(0, storedAtNano), + ttl: time.Duration(ttlNs), + } + + if !e.expired() { + c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e + } else { + c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class) + } + } +} + + +func computeTTL(msg *dns.Msg) time.Duration { + if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + } + + var min uint32 + first := true + for _, rr := range msg.Answer { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Ns { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + if !first { + return time.Duration(min) * time.Second + } + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + return 0 +} + +func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { + ratio := float64(remaining) / float64(originalTTL) + for _, rr := range msg.Answer { + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Ns { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } +} + +func applyRatio(ttl uint32, ratio float64) uint32 { + return uint32(float64(ttl) * ratio) +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..6556dcb --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,140 @@ +package cache + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestSetGet(t *testing.T) { + c, err := NewCache(100, "") + if err != nil { + t.Fatal(err) + } + defer c.Stop() + + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300}, + A: net.IPv4(1,2,3,4), + }) + + key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 300*time.Second) + + got, ok := c.Get(key) + if !ok { + t.Fatal("expected cache hit") + } + if len(got.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(got.Answer)) + } + a, _ := got.Answer[0].(*dns.A) + if !a.A.Equal(net.IPv4(1,2,3,4)) { + t.Errorf("IP = %s, want 1.2.3.4", a.A) + } +} + +func TestExpiry(t *testing.T) { + c, _ := NewCache(10, "") + defer c.Stop() + msg := new(dns.Msg) + key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 30*time.Millisecond) + time.Sleep(60 * time.Millisecond) + _, ok := c.Get(key) + if ok { + t.Fatal("expected miss after expiry") + } +} + +func TestEviction(t *testing.T) { + c, _ := NewCache(2, "") + defer c.Stop() + for i := 0; i < 5; i++ { + name := fmt.Sprintf("d%d.com.", i) + msg := new(dns.Msg) + c.Set(Key{Name: name, Qtype: dns.TypeA, Class: dns.ClassINET}, msg, 60*time.Second) + } + if c.Len() > 2 { + t.Errorf("expected ≤2 entries, got %d", c.Len()) + } +} + +func TestComputeTTL(t *testing.T) { + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120}, + }) + if d := computeTTL(msg); d != 120*time.Second { + t.Errorf("TTL = %v, want 120s", d) + } +} + +func TestNegativeTTL(t *testing.T) { + msg := new(dns.Msg) + msg.Rcode = dns.RcodeNameError + msg.Ns = append(msg.Ns, &dns.SOA{ + Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900}, + Minttl: 300, + }) + if d := computeTTL(msg); d != 300*time.Second { + t.Errorf("negative TTL = %v, want 300s", d) + } +} + +func TestRace(t *testing.T) { + c, _ := NewCache(1000, "") + defer c.Stop() + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + msg := new(dns.Msg) + c.Set(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}, msg, time.Second) + } + close(done) + }() + for i := 0; i < 100; i++ { + c.Get(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}) + } + <-done + c.Stats() + c.Len() +} + +func TestSQLitePersistence(t *testing.T) { + dir := t.TempDir() + dbPath := dir + "/cache.db" + + c, err := NewCache(100, dbPath) + if err != nil { + t.Fatal(err) + } + + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300}, + A: net.IPv4(1, 2, 3, 4), + }) + key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 300*time.Second) + c.Stop() + + c2, err := NewCache(100, dbPath) + if err != nil { + t.Fatal(err) + } + defer c2.Stop() + + got, ok := c2.Get(key) + if !ok { + t.Fatal("expected cache hit from SQLite load") + } + a, _ := got.Answer[0].(*dns.A) + if !a.A.Equal(net.IPv4(1, 2, 3, 4)) { + t.Errorf("IP = %s, want 1.2.3.4", a.A) + } +} diff --git a/internal/server/doh.go b/internal/server/doh.go index 3f5a538..b46736e 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) packed, err := resp.Pack() if err != nil { http.Error(w, "pack response", http.StatusInternalServerError) diff --git a/internal/server/handler.go b/internal/server/handler.go index bac1c81..4aa771f 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,9 +2,13 @@ package server import ( "context" - "github.com/miekg/dns" "log/slog" + "net" "time" + + "github.com/miekg/dns" + "sdns/internal/blocklist" + "sdns/internal/cache" ) func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { @@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { return } - resp := s.buildResponse(req) + resp, blocked := s.buildResponse(req) if err := w.WriteMsg(resp); err != nil { slog.Error("write response failed", - "err", err, + "err", err, "qname", req.Question[0].Name, "qtype", dns.TypeToString[req.Question[0].Qtype], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], - "rcode", dns.RcodeToString[resp.Rcode], - "client", w.RemoteAddr().String(), + "qname", req.Question[0].Name, + "qtype", dns.TypeToString[req.Question[0].Qtype], + "rcode", dns.RcodeToString[resp.Rcode], + "client", w.RemoteAddr().String(), + "blocked", blocked, ) } -func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { +func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - return m + return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false } q := req.Question[0] - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = false - if opt := req.IsEdns0(); opt != nil { - resp.SetEdns0(4096, false) + if s.cache != nil { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + if cached, ok := s.cache.Get(key); ok { + return cached, false + } + } + + if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + return s.blockedResponse(req), true } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", q.Name, "qtype", dns.TypeToString[q.Qtype], ) - resp.Rcode = dns.RcodeServerFailure - return resp + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeServerFailure) + return m, false + } + + reply.Id = req.Id + + if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + s.cache.Set(key, reply, 0) + } + + return reply, false +} + +func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + + if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { + m.Rcode = dns.RcodeNameError + return m } - return reply + q := req.Question[0] + switch q.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.IPv4(0, 0, 0, 0), + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, + AAAA: net.IPv6zero, + }) + } + return m } diff --git a/internal/server/server.go b/internal/server/server.go index 3114073..e0490bd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,18 +8,23 @@ import ( "github.com/miekg/dns" "sdns/internal/resolver" + "sdns/internal/blocklist" + "sdns/internal/cache" ) type Server struct { logger *slog.Logger resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist udp *dns.Server tcp *dns.Server doh *http.Server } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) { - s := &Server{logger: logger, resolver: r} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, +r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { + s := &Server{logger: logger, resolver: r, cache: c, blocklist: b} mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 6fc2092..c49d5f3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -40,7 +40,7 @@ func TestBuildResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := s.buildResponse(tt.req) + resp, _ := s.buildResponse(tt.req) if resp.Rcode != tt.wantRcode { t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode) } @@ -63,7 +63,7 @@ func TestBuildResponseWithQuery(t *testing.T) { // Valid query → must not panic, rcode must be valid m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) - resp := s.buildResponse(m) + resp, _ := s.buildResponse(m) if resp == nil { t.Fatal("buildResponse returned nil") } @@ -100,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) { if err := msg.Unpack(data); err != nil { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } @@ -5,8 +5,11 @@ import ( "log/slog" "os" "os/signal" + "path/filepath" "syscall" + "sdns/internal/blocklist" + "sdns/internal/cache" "sdns/internal/resolver" "sdns/internal/server" ) @@ -19,6 +22,33 @@ func main() { r := resolver.New() + dbPath := os.Getenv("SDNS_CACHE_DB") + if dbPath != "" { + logger.Info("cache using sqlite", "path", dbPath) + } else { + logger.Info("cache using in-memory") + } + c, err := cache.NewCache(100000, dbPath) + if err != nil { + logger.Error("create cache failed", "err", err) + os.Exit(1) + } + defer c.Stop() + + var bl *blocklist.Blocklist + matches, _ := filepath.Glob("etc/blocklist/*.txt") + if len(matches) > 0 { + bl = blocklist.New(blocklist.ResponseZeroIP) + for _, f := range matches { + if err := bl.LoadFile(f); err != nil { + logger.Error("load blocklist failed", "file", f, "err", err) + os.Exit(1) + } + logger.Info("blocklist loaded", "file", f, "rules", bl.TotalRules) + } + } else { + logger.Info("no blocklist files in etc/blocklist/, ad-blocking disabled") + } udp := os.Getenv("SDNS_LISTEN_UDP") if udp == "" { udp = ":5353" @@ -37,7 +67,7 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - srv, err := server.New(udp, tcp, doh, logger, r) + srv, err := server.New(udp, tcp, doh, logger, r, c, bl) if err != nil { logger.Error("create server failed", "err", err) os.Exit(1) |
