diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/blocklist/blocklist.go | 181 | ||||
| -rw-r--r-- | internal/blocklist/blocklist_test.go | 78 | ||||
| -rw-r--r-- | internal/cache/cache.go | 317 | ||||
| -rw-r--r-- | internal/cache/cache_test.go | 140 | ||||
| -rw-r--r-- | internal/server/doh.go | 2 | ||||
| -rw-r--r-- | internal/server/handler.go | 81 | ||||
| -rw-r--r-- | internal/server/server.go | 9 | ||||
| -rw-r--r-- | internal/server/server_test.go | 6 |
8 files changed, 788 insertions, 26 deletions
diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go new file mode 100644 index 0000000..ae6d20f --- /dev/null +++ b/internal/blocklist/blocklist.go @@ -0,0 +1,181 @@ +package blocklist + +import ( + "bufio" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" +) + +type ResponseAction int + +const ( + ResponseZeroIP ResponseAction = iota + ResponseNXDOMAIN +) + +type Blocklist struct { + mu sync.RWMutex + blocked *trie + exceptions *trie + response ResponseAction + TotalRules int32 + Hits int64 +} + +func New(action ResponseAction) *Blocklist { + return &Blocklist{ + blocked: newTrie(), + exceptions: newTrie(), + response: action, + } +} + +func (b *Blocklist) Response() ResponseAction { + return b.response +} + +func (b *Blocklist) IsBlocked(domain string) bool { + b.mu.RLock() + defer b.mu.RUnlock() + + labels := splitDomain(domain) + if !b.blocked.match(labels) { + return false + } + if b.exceptions.match(labels) { + return false + } + atomic.AddInt64(&b.Hits, 1) + return true +} + +func (b *Blocklist) LoadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' || line[0] == '!' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.StoreInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) LoadURL(url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + var n int32 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' { + continue + } + if b.addRule(line) { + n++ + } + } + atomic.AddInt32(&b.TotalRules, n) + return scanner.Err() +} + +func (b *Blocklist) addRule(line string) bool { + b.mu.Lock() + defer b.mu.Unlock() + + if strings.HasPrefix(line, "@@") { + domain := strings.TrimPrefix(line, "@@") + domain = strings.TrimPrefix(domain, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.exceptions.insert(splitDomain(domain)) + return true + } + + fields := strings.Fields(line) + if len(fields) >= 2 { + ip := fields[0] + if ip == "0.0.0.0" || ip == "127.0.0.1" || ip == "::1" || ip == "::" { + b.blocked.insert(splitDomain(fields[len(fields)-1])) + return true + } + } + if strings.HasPrefix(line, "||") { + domain := strings.TrimPrefix(line, "||") + if idx := strings.Index(domain, "^"); idx > 0 { + domain = domain[:idx] + } + b.blocked.insert(splitDomain(domain)) + return true + } + if strings.Contains(line, ".") && !strings.ContainsAny(line, " /") { + b.blocked.insert(splitDomain(line)) + return true + } + return false +} + +func splitDomain(domain string) []string { + domain = strings.TrimSuffix(domain, ".") + return strings.Split(domain, ".") +} + +type trieNode struct { + children map[string]*trieNode + terminal bool +} + +type trie struct{ root *trieNode } + +func newTrie() *trie { + return &trie{ + root: &trieNode{ + children: make(map[string]*trieNode), + }, + } +} +func (t *trie) insert(labels []string) { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + child, ok := node.children[labels[i]] + if !ok { + child = &trieNode{children: make(map[string]*trieNode)} + node.children[labels[i]] = child + } + node = child + } + node.terminal = true +} + +func (t *trie) match(labels []string) bool { + node := t.root + for i := len(labels) - 1; i >= 0; i-- { + if node.terminal { + return true + } + child, ok := node.children[labels[i]] + if !ok { + return false + } + node = child + } + return node.terminal +} diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go new file mode 100644 index 0000000..9ae1749 --- /dev/null +++ b/internal/blocklist/blocklist_test.go @@ -0,0 +1,78 @@ +package blocklist + +import ( + "os" + "path/filepath" + "testing" +) + +func TestBlockZeroIP(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("0.0.0.0 doubleclick.net") + + cases := []struct { + domain string + block bool + }{ + {"example.com.", true}, + {"sub.example.com.", true}, + {"notexample.com.", false}, + {"doubleclick.net.", true}, + {"ads.doubleclick.net.", true}, + {"google.com.", false}, + } + for _, c := range cases { + got := b.IsBlocked(c.domain) + if got != c.block { + t.Errorf("IsBlocked(%q) = %v, want %v", c.domain, got, c.block) + } + } +} + +func TestException(t *testing.T) { + b := New(ResponseZeroIP) + b.addRule("||example.com^") + b.addRule("@@||whitelist.example.com^") + + if !b.IsBlocked("example.com.") { + t.Error("expected blocked") + } + if b.IsBlocked("whitelist.example.com.") { + t.Error("expected NOT blocked (exception)") + } + if b.IsBlocked("sub.whitelist.example.com.") { + t.Error("expected NOT blocked (exception subdomain)") + } +} + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "block.txt") + os.WriteFile(p, []byte("0.0.0.0 ads.com\n||tracker.net^\n"), 0644) + + b := New(ResponseZeroIP) + if err := b.LoadFile(p); err != nil { + t.Fatal(err) + } + if !b.IsBlocked("ads.com.") { + t.Error("expected blocked") + } + if !b.IsBlocked("sub.tracker.net.") { + t.Error("expected blocked") + } + if b.IsBlocked("safe.org.") { + t.Error("expected NOT blocked") + } +} + +func TestResponseType(t *testing.T) { + b := New(ResponseNXDOMAIN) + if b.Response() != ResponseNXDOMAIN { + t.Error("expected NXDOMAIN") + } + b2 := New(ResponseZeroIP) + if b2.Response() != ResponseZeroIP { + t.Error("expected ZeroIP") + } +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a2d86a0 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,317 @@ +package cache + +import ( + "database/sql" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + _ "modernc.org/sqlite" +) + +type Key struct { + Name string + Qtype uint16 + Class uint16 +} + +type entry struct { + msg *dns.Msg + storedAt time.Time + ttl time.Duration +} + +func (e *entry) expired() bool { + return time.Since(e.storedAt) >= e.ttl +} + +func (e *entry) remaining() time.Duration { + d := e.ttl - time.Since(e.storedAt) + if d < 0 { + return 0 + } + return d +} + +type Cache struct { + mu sync.RWMutex + entries map[Key]*entry + maxSize int + db *sql.DB + + hits int64 + misses int64 + evicted int64 + stopCh chan struct{} +} + +func NewCache(maxSize int, dbPath string) (*Cache, error) { + if maxSize <= 0 { + maxSize = 100000 + } + c := &Cache{ + entries: make(map[Key]*entry), + maxSize: maxSize, + stopCh: make(chan struct{}), + } + + if dbPath != "" { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + c.db = db + + db.Exec("PRAGMA journal_mode=WAL") + db.Exec("PRAGMA synchronous=NORMAL") + db.Exec("PRAGMA cache_size=-65536") + + if _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS cache ( + name TEXT NOT NULL, + qtype INTEGER NOT NULL, + class INTEGER NOT NULL, + data BLOB NOT NULL, + stored_at INTEGER NOT NULL, + ttl_ns INTEGER NOT NULL, + PRIMARY KEY (name, qtype, class) + ) + `); err != nil { + db.Close() + return nil, err + } + + c.loadFromDB() + } + go c.evictLoop() + return c, nil +} + +func (c *Cache) Stop() { + close(c.stopCh) + if c.db != nil { + c.db.Close() + } +} + +func (c *Cache) Get(key Key) (*dns.Msg, bool) { + c.mu.RLock() + e, ok := c.entries[key] + if !ok || e.expired() { + c.mu.RUnlock() + atomic.AddInt64(&c.misses, 1) + return nil, false + } + c.mu.RUnlock() + + atomic.AddInt64(&c.hits, 1) + msg := e.msg.Copy() + remaining := e.remaining() + adjustTTL(msg, e.ttl, remaining) + return msg, true +} + +func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { + if ttl <= 0 { + ttl = computeTTL(msg) + } + if ttl <= 0 { + ttl = 60 * time.Second + } + + e := &entry{ + msg: msg.Copy(), + storedAt: time.Now(), + ttl: ttl, + } + + c.mu.Lock() + if len(c.entries) >= c.maxSize { + c.evictLocked() + } + c.entries[key] = e + c.mu.Unlock() + + if c.db != nil { + c.writeToDB(key, e) + } +} + +func (c *Cache) Stats() (int64, int64, int64) { + return atomic.LoadInt64(&c.hits), + atomic.LoadInt64(&c.misses), + atomic.LoadInt64(&c.evicted) +} + +func (c *Cache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.entries) +} + +func (c *Cache) evictLocked() { + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + if len(c.entries) < c.maxSize { + return + } + + var oldestKey Key + var oldestTime time.Time + first := true + for k, e := range c.entries { + if first || e.storedAt.Before(oldestTime) { + oldestKey = k + oldestTime = e.storedAt + first = false + } + } + if !first { + delete(c.entries, oldestKey) + atomic.AddInt64(&c.evicted, 1) + } +} + +func (c *Cache) evictLoop() { + tk := time.NewTicker(1 * time.Minute) + defer tk.Stop() + for { + select { + case <-tk.C: + c.mu.Lock() + now := time.Now() + for k, e := range c.entries { + if now.After(e.storedAt.Add(e.ttl)) { + delete(c.entries, k) + atomic.AddInt64(&c.evicted, 1) + } + } + c.mu.Unlock() + case <-c.stopCh: + return + } + } +} +func (c *Cache) writeToDB(key Key, e *entry) { + data, err := e.msg.Pack() + if err != nil { + return + } + c.db.Exec( + `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) + VALUES (?, ?, ?, ?, ?, ?)`, + key.Name, key.Qtype, key.Class, data, + e.storedAt.UnixNano(), int64(e.ttl), + ) +} + +func (c *Cache) loadFromDB() { + rows, err := c.db.Query( + `SELECT name, qtype, class, data, stored_at, ttl_ns FROM cache`, + ) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var name string + var qtype, class uint16 + var data []byte + var storedAtNano, ttlNs int64 + + if err := rows.Scan(&name, &qtype, &class, &data, &storedAtNano, &ttlNs); err != nil { + continue + } + + msg := new(dns.Msg) + if err := msg.Unpack(data); err != nil { + continue + } + + e := &entry{ + msg: msg, + storedAt: time.Unix(0, storedAtNano), + ttl: time.Duration(ttlNs), + } + + if !e.expired() { + c.entries[Key{Name: name, Qtype: qtype, Class: class}] = e + } else { + c.db.Exec(`DELETE FROM cache WHERE name=? AND qtype=? AND class=?`, name, qtype, class) + } + } +} + + +func computeTTL(msg *dns.Msg) time.Duration { + if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + } + + var min uint32 + first := true + for _, rr := range msg.Answer { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Ns { + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + if first || rr.Header().Ttl < min { + min = rr.Header().Ttl + first = false + } + } + if !first { + return time.Duration(min) * time.Second + } + for _, rr := range msg.Ns { + if soa, ok := rr.(*dns.SOA); ok { + return time.Duration(soa.Minttl) * time.Second + } + } + return 0 +} + +func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { + ratio := float64(remaining) / float64(originalTTL) + for _, rr := range msg.Answer { + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Ns { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } + for _, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + continue + } + rr.Header().Ttl = applyRatio(rr.Header().Ttl, ratio) + } +} + +func applyRatio(ttl uint32, ratio float64) uint32 { + return uint32(float64(ttl) * ratio) +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..6556dcb --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,140 @@ +package cache + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestSetGet(t *testing.T) { + c, err := NewCache(100, "") + if err != nil { + t.Fatal(err) + } + defer c.Stop() + + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Ttl:300}, + A: net.IPv4(1,2,3,4), + }) + + key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 300*time.Second) + + got, ok := c.Get(key) + if !ok { + t.Fatal("expected cache hit") + } + if len(got.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(got.Answer)) + } + a, _ := got.Answer[0].(*dns.A) + if !a.A.Equal(net.IPv4(1,2,3,4)) { + t.Errorf("IP = %s, want 1.2.3.4", a.A) + } +} + +func TestExpiry(t *testing.T) { + c, _ := NewCache(10, "") + defer c.Stop() + msg := new(dns.Msg) + key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 30*time.Millisecond) + time.Sleep(60 * time.Millisecond) + _, ok := c.Get(key) + if ok { + t.Fatal("expected miss after expiry") + } +} + +func TestEviction(t *testing.T) { + c, _ := NewCache(2, "") + defer c.Stop() + for i := 0; i < 5; i++ { + name := fmt.Sprintf("d%d.com.", i) + msg := new(dns.Msg) + c.Set(Key{Name: name, Qtype: dns.TypeA, Class: dns.ClassINET}, msg, 60*time.Second) + } + if c.Len() > 2 { + t.Errorf("expected ≤2 entries, got %d", c.Len()) + } +} + +func TestComputeTTL(t *testing.T) { + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA, Ttl: 120}, + }) + if d := computeTTL(msg); d != 120*time.Second { + t.Errorf("TTL = %v, want 120s", d) + } +} + +func TestNegativeTTL(t *testing.T) { + msg := new(dns.Msg) + msg.Rcode = dns.RcodeNameError + msg.Ns = append(msg.Ns, &dns.SOA{ + Hdr: dns.RR_Header{Name: "com.", Rrtype: dns.TypeSOA, Ttl: 900}, + Minttl: 300, + }) + if d := computeTTL(msg); d != 300*time.Second { + t.Errorf("negative TTL = %v, want 300s", d) + } +} + +func TestRace(t *testing.T) { + c, _ := NewCache(1000, "") + defer c.Stop() + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + msg := new(dns.Msg) + c.Set(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}, msg, time.Second) + } + close(done) + }() + for i := 0; i < 100; i++ { + c.Get(Key{Name: fmt.Sprintf("d%d.com.", i), Qtype: dns.TypeA, Class: dns.ClassINET}) + } + <-done + c.Stats() + c.Len() +} + +func TestSQLitePersistence(t *testing.T) { + dir := t.TempDir() + dbPath := dir + "/cache.db" + + c, err := NewCache(100, dbPath) + if err != nil { + t.Fatal(err) + } + + msg := new(dns.Msg) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: "x.com.", Rrtype: dns.TypeA, Ttl: 300}, + A: net.IPv4(1, 2, 3, 4), + }) + key := Key{Name: "x.com.", Qtype: dns.TypeA, Class: dns.ClassINET} + c.Set(key, msg, 300*time.Second) + c.Stop() + + c2, err := NewCache(100, dbPath) + if err != nil { + t.Fatal(err) + } + defer c2.Stop() + + got, ok := c2.Get(key) + if !ok { + t.Fatal("expected cache hit from SQLite load") + } + a, _ := got.Answer[0].(*dns.A) + if !a.A.Equal(net.IPv4(1, 2, 3, 4)) { + t.Errorf("IP = %s, want 1.2.3.4", a.A) + } +} diff --git a/internal/server/doh.go b/internal/server/doh.go index 3f5a538..b46736e 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) packed, err := resp.Pack() if err != nil { http.Error(w, "pack response", http.StatusInternalServerError) diff --git a/internal/server/handler.go b/internal/server/handler.go index bac1c81..4aa771f 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,9 +2,13 @@ package server import ( "context" - "github.com/miekg/dns" "log/slog" + "net" "time" + + "github.com/miekg/dns" + "sdns/internal/blocklist" + "sdns/internal/cache" ) func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { @@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { return } - resp := s.buildResponse(req) + resp, blocked := s.buildResponse(req) if err := w.WriteMsg(resp); err != nil { slog.Error("write response failed", - "err", err, + "err", err, "qname", req.Question[0].Name, "qtype", dns.TypeToString[req.Question[0].Qtype], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], - "rcode", dns.RcodeToString[resp.Rcode], - "client", w.RemoteAddr().String(), + "qname", req.Question[0].Name, + "qtype", dns.TypeToString[req.Question[0].Qtype], + "rcode", dns.RcodeToString[resp.Rcode], + "client", w.RemoteAddr().String(), + "blocked", blocked, ) } -func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { +func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - return m + return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false } q := req.Question[0] - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = false - if opt := req.IsEdns0(); opt != nil { - resp.SetEdns0(4096, false) + if s.cache != nil { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + if cached, ok := s.cache.Get(key); ok { + return cached, false + } + } + + if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + return s.blockedResponse(req), true } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", q.Name, "qtype", dns.TypeToString[q.Qtype], ) - resp.Rcode = dns.RcodeServerFailure - return resp + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeServerFailure) + return m, false + } + + reply.Id = req.Id + + if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + s.cache.Set(key, reply, 0) + } + + return reply, false +} + +func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + + if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { + m.Rcode = dns.RcodeNameError + return m } - return reply + q := req.Question[0] + switch q.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.IPv4(0, 0, 0, 0), + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, + AAAA: net.IPv6zero, + }) + } + return m } diff --git a/internal/server/server.go b/internal/server/server.go index 3114073..e0490bd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,18 +8,23 @@ import ( "github.com/miekg/dns" "sdns/internal/resolver" + "sdns/internal/blocklist" + "sdns/internal/cache" ) type Server struct { logger *slog.Logger resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist udp *dns.Server tcp *dns.Server doh *http.Server } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) { - s := &Server{logger: logger, resolver: r} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, +r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { + s := &Server{logger: logger, resolver: r, cache: c, blocklist: b} mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 6fc2092..c49d5f3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -40,7 +40,7 @@ func TestBuildResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := s.buildResponse(tt.req) + resp, _ := s.buildResponse(tt.req) if resp.Rcode != tt.wantRcode { t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode) } @@ -63,7 +63,7 @@ func TestBuildResponseWithQuery(t *testing.T) { // Valid query → must not panic, rcode must be valid m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) - resp := s.buildResponse(m) + resp, _ := s.buildResponse(m) if resp == nil { t.Fatal("buildResponse returned nil") } @@ -100,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) { if err := msg.Unpack(data); err != nil { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } |
