From 4e6a897a0b55ee533c05f89fa38dbe0704f2798d Mon Sep 17 00:00:00 2001 From: radhitya Date: Sun, 14 Jun 2026 14:36:32 +0700 Subject: dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache --- internal/server/doh.go | 2 +- internal/server/handler.go | 81 +++++++++++++++++++++++++++++++----------- internal/server/server.go | 9 +++-- internal/server/server_test.go | 6 ++-- 4 files changed, 72 insertions(+), 26 deletions(-) (limited to 'internal/server') diff --git a/internal/server/doh.go b/internal/server/doh.go index 3f5a538..b46736e 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) packed, err := resp.Pack() if err != nil { http.Error(w, "pack response", http.StatusInternalServerError) diff --git a/internal/server/handler.go b/internal/server/handler.go index bac1c81..4aa771f 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,9 +2,13 @@ package server import ( "context" - "github.com/miekg/dns" "log/slog" + "net" "time" + + "github.com/miekg/dns" + "sdns/internal/blocklist" + "sdns/internal/cache" ) func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { @@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { return } - resp := s.buildResponse(req) + resp, blocked := s.buildResponse(req) if err := w.WriteMsg(resp); err != nil { slog.Error("write response failed", - "err", err, + "err", err, "qname", req.Question[0].Name, "qtype", dns.TypeToString[req.Question[0].Qtype], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], - "rcode", dns.RcodeToString[resp.Rcode], - "client", w.RemoteAddr().String(), + "qname", req.Question[0].Name, + "qtype", dns.TypeToString[req.Question[0].Qtype], + "rcode", dns.RcodeToString[resp.Rcode], + "client", w.RemoteAddr().String(), + "blocked", blocked, ) } -func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { +func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - return m + return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false } q := req.Question[0] - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = false - if opt := req.IsEdns0(); opt != nil { - resp.SetEdns0(4096, false) + if s.cache != nil { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + if cached, ok := s.cache.Get(key); ok { + return cached, false + } + } + + if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + return s.blockedResponse(req), true } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg { reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", q.Name, "qtype", dns.TypeToString[q.Qtype], ) - resp.Rcode = dns.RcodeServerFailure - return resp + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeServerFailure) + return m, false + } + + reply.Id = req.Id + + if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { + key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + s.cache.Set(key, reply, 0) + } + + return reply, false +} + +func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + + if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { + m.Rcode = dns.RcodeNameError + return m } - return reply + q := req.Question[0] + switch q.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.IPv4(0, 0, 0, 0), + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, + AAAA: net.IPv6zero, + }) + } + return m } diff --git a/internal/server/server.go b/internal/server/server.go index 3114073..e0490bd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,18 +8,23 @@ import ( "github.com/miekg/dns" "sdns/internal/resolver" + "sdns/internal/blocklist" + "sdns/internal/cache" ) type Server struct { logger *slog.Logger resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist udp *dns.Server tcp *dns.Server doh *http.Server } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) { - s := &Server{logger: logger, resolver: r} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, +r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { + s := &Server{logger: logger, resolver: r, cache: c, blocklist: b} mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 6fc2092..c49d5f3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -40,7 +40,7 @@ func TestBuildResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := s.buildResponse(tt.req) + resp, _ := s.buildResponse(tt.req) if resp.Rcode != tt.wantRcode { t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode) } @@ -63,7 +63,7 @@ func TestBuildResponseWithQuery(t *testing.T) { // Valid query → must not panic, rcode must be valid m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) - resp := s.buildResponse(m) + resp, _ := s.buildResponse(m) if resp == nil { t.Fatal("buildResponse returned nil") } @@ -100,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) { if err := msg.Unpack(data); err != nil { return } - resp := s.buildResponse(msg) + resp, _ := s.buildResponse(msg) if resp == nil { t.Fatal("buildResponse returned nil") } -- cgit v1.2.3