summaryrefslogtreecommitdiff
path: root/internal/server/handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/handler.go')
-rw-r--r--internal/server/handler.go81
1 files changed, 61 insertions, 20 deletions
diff --git a/internal/server/handler.go b/internal/server/handler.go
index bac1c81..4aa771f 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -2,9 +2,13 @@ package server
import (
"context"
- "github.com/miekg/dns"
"log/slog"
+ "net"
"time"
+
+ "github.com/miekg/dns"
+ "sdns/internal/blocklist"
+ "sdns/internal/cache"
)
func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
@@ -15,37 +19,40 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
return
}
- resp := s.buildResponse(req)
+ resp, blocked := s.buildResponse(req)
if err := w.WriteMsg(resp); err != nil {
slog.Error("write response failed",
- "err", err,
+ "err", err,
"qname", req.Question[0].Name,
"qtype", dns.TypeToString[req.Question[0].Qtype],
)
return
}
slog.Info("query served",
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
- "rcode", dns.RcodeToString[resp.Rcode],
- "client", w.RemoteAddr().String(),
+ "qname", req.Question[0].Name,
+ "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ "client", w.RemoteAddr().String(),
+ "blocked", blocked,
)
}
-func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
+func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
- m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- return m
+ return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false
}
q := req.Question[0]
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Authoritative = false
- if opt := req.IsEdns0(); opt != nil {
- resp.SetEdns0(4096, false)
+ if s.cache != nil {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ if cached, ok := s.cache.Get(key); ok {
+ return cached, false
+ }
+ }
+
+ if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) {
+ return s.blockedResponse(req), true
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -54,13 +61,47 @@ func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", q.Name,
"qtype", dns.TypeToString[q.Qtype],
)
- resp.Rcode = dns.RcodeServerFailure
- return resp
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeServerFailure)
+ return m, false
+ }
+
+ reply.Id = req.Id
+
+ if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
+ key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ s.cache.Set(key, reply, 0)
+ }
+
+ return reply, false
+}
+
+func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
+ m := new(dns.Msg)
+ m.SetReply(req)
+ m.Authoritative = true
+
+ if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
+ m.Rcode = dns.RcodeNameError
+ return m
}
- return reply
+ q := req.Question[0]
+ switch q.Qtype {
+ case dns.TypeA:
+ m.Answer = append(m.Answer, &dns.A{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
+ A: net.IPv4(0, 0, 0, 0),
+ })
+ case dns.TypeAAAA:
+ m.Answer = append(m.Answer, &dns.AAAA{
+ Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
+ AAAA: net.IPv6zero,
+ })
+ }
+ return m
}