summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
committerradhitya <alif@radhitya.org>2026-06-14 14:36:32 +0700
commit4e6a897a0b55ee533c05f89fa38dbe0704f2798d (patch)
tree12d9700e53775503ad7ba2beb72bedfc64bdd70d /internal/server
parent3e44adc94f32bfe500730fcbf1c02cedf65b0a30 (diff)
dns recursive resolver(iterative, root hints, delegfation, glue, fallback), adblocker, dns cache
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/doh.go2
-rw-r--r--internal/server/handler.go81
-rw-r--r--internal/server/server.go9
-rw-r--r--internal/server/server_test.go6
4 files changed, 72 insertions, 26 deletions
diff --git a/internal/server/doh.go b/internal/server/doh.go
index 3f5a538..b46736e 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
return
}
- resp := s.buildResponse(msg)
+ resp, _ := s.buildResponse(msg)
packed, err := resp.Pack()
if err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
diff --git a/internal/server/handler.go b/internal/server/handler.go
index bac1c81..4aa771f 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -2,9 +2,13 @@ package server
import (
"context"
- "github.com/miekg/dns"
"log/slog"
+ "net"
"time"
+
+ "github.com/miekg/dns"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
)
func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
@@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
return
}
- resp := s.buildResponse(req)
+ resp, blocked := s.buildResponse(req)
if err := w.WriteMsg(resp); err != nil {
slog.Error("write response failed",
- "err", err,
+ "err", err,
"qname", req.Question[0].Name,
"qtype", dns.TypeToString[req.Question[0].Qtype],
)
return
}
slog.Info("query served",
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
- "rcode", dns.RcodeToString[resp.Rcode],
- "client", w.RemoteAddr().String(),
+ "qname", req.Question[0].Name,
+ "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ "client", w.RemoteAddr().String(),
+ "blocked", blocked,
)
}
-func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
+func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
- m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- return m
+ return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false
}
q := req.Question[0]
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Authoritative = false
- if opt := req.IsEdns0(); opt != nil {
- resp.SetEdns0(4096, false)
+ if s.cache != nil {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ if cached, ok := s.cache.Get(key); ok {
+ return cached, false
+ }
+ }
+
+ if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) {
+ return s.blockedResponse(req), true
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", q.Name,
"qtype", dns.TypeToString[q.Qtype],
)
- resp.Rcode = dns.RcodeServerFailure
- return resp
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeServerFailure)
+ return m, false
+ }
+
+ reply.Id = req.Id
+
+ if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ s.cache.Set(key, reply, 0)
+ }
+
+ return reply, false
+}
+
+func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ m.Authoritative = true
+
+ if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
+ m.Rcode = dns.RcodeNameError
+ return m
}
- return reply
+ q := req.Question[0]
+ switch q.Qtype {
+ case dns.TypeA:
+ m.Answer = append(m.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
+ A: net.IPv4(0, 0, 0, 0),
+ })
+ case dns.TypeAAAA:
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
+ AAAA: net.IPv6zero,
+ })
+ }
+ return m
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 3114073..e0490bd 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -8,18 +8,23 @@ import (
"github.com/miekg/dns"
"sdns/internal/resolver"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
)
type Server struct {
logger *slog.Logger
resolver *resolver.Resolver
+ cache *cache.Cache
+ blocklist *blocklist.Blocklist
udp *dns.Server
tcp *dns.Server
doh *http.Server
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) {
- s := &Server{logger: logger, resolver: r}
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
+r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
+ s := &Server{logger: logger, resolver: r, cache: c, blocklist: b}
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 6fc2092..c49d5f3 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -40,7 +40,7 @@ func TestBuildResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- resp := s.buildResponse(tt.req)
+ resp, _ := s.buildResponse(tt.req)
if resp.Rcode != tt.wantRcode {
t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode)
}
@@ -63,7 +63,7 @@ func TestBuildResponseWithQuery(t *testing.T) {
// Valid query → must not panic, rcode must be valid
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
- resp := s.buildResponse(m)
+ resp, _ := s.buildResponse(m)
if resp == nil {
t.Fatal("buildResponse returned nil")
}
@@ -100,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) {
if err := msg.Unpack(data); err != nil {
return
}
- resp := s.buildResponse(msg)
+ resp, _ := s.buildResponse(msg)
if resp == nil {
t.Fatal("buildResponse returned nil")
}