summaryrefslogtreecommitdiff
path: root/internal/server/handler.go
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
committerradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
commitb7359e1d45f505171356bcae3c7d5e2341ecc859 (patch)
treef91d4a4b08ce279d488a76e9b7141e69fc844ea9 /internal/server/handler.go
parent2b1f613c42de3861141eb6f93c1740b6937ee183 (diff)
forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply
Diffstat (limited to 'internal/server/handler.go')
-rw-r--r--internal/server/handler.go121
1 files changed, 94 insertions, 27 deletions
diff --git a/internal/server/handler.go b/internal/server/handler.go
index d0aa705..4516468 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -4,6 +4,7 @@ import (
"context"
"io"
"log/slog"
+ "net"
"net/netip"
"time"
@@ -14,6 +15,18 @@ import (
)
func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ clientIP := parseClientIP(w.RemoteAddr())
+ if !s.isAllowed(clientIP) {
+ slog.Warn("query denied by ACL", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+ if !s.rateLimit(clientIP.String()) {
+ slog.Warn("query rate limited", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+
if len(req.Question) == 0 {
m := new(dns.Msg)
m.Rcode = dns.RcodeFormatError
@@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns
io.Copy(w, m)
return
}
+ q := req.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
- resp, blocked := s.buildResponse(req)
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
+ resp := s.blockedResponse(req)
+ resp.ID = req.ID
+ io.Copy(w, resp)
+ slog.Info("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode],
+ "blocked", true)
+ return
+ }
+ if s.cache != nil {
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
+ if packed, ok := s.cache.Get(key); ok {
+ packed[0] = byte(req.ID >> 8)
+ packed[1] = byte(req.ID)
+ io.Copy(w, &dns.Msg{Data: packed})
+ slog.Debug("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "cached", true)
+ return
+ }
+ }
+ resp, blocked := s.buildResponse(req)
resp.ID = req.ID
resp.Data = nil
if _, err := io.Copy(w, resp); err != nil {
- slog.Error("write response failed",
- "err", err,
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- )
+ slog.Error("write failed", "err", err, "qname", qname)
return
}
- slog.Info("query served",
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- "rcode", dns.RcodeToString[resp.Rcode],
- "blocked", blocked,
- )
+ slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype],
+ "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked)
}
-
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
m := new(dns.Msg)
@@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
m.Question = req.Question
return m, false
}
+
q := req.Question[0]
qname := q.Header().Name
qtype := dns.RRToType(q)
qclass := q.Header().Class
- if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
- return s.blockedResponse(req), true
- }
-
- if s.cache != nil {
- key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
- if cached, ok := s.cache.Get(key); ok {
- cached.ID = req.ID
- return cached, false
- }
- }
-
ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
reply, err := s.resolver.Lookup(ctx, qname, qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", qname,
"qtype", dns.TypeToString[qtype],
)
@@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
})
case dns.TypeAAAA:
m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
})
}
return m
}
+
+func parseClientIP(addr net.Addr) net.IP {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.IP
+ case *net.TCPAddr:
+ return a.IP
+ }
+ host, _, _ := net.SplitHostPort(addr.String())
+ return net.ParseIP(host)
+}
+
+func parseHTTPClientIP(addr string) net.IP {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return net.ParseIP(addr)
+ }
+ return net.ParseIP(host)
+}
+
+func (s *Server) isAllowed(ip net.IP) bool {
+ if ip == nil {
+ return false
+ }
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ for _, n := range s.aclNets {
+ if n.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *Server) rateLimit(ip string) bool {
+ if s.rateLimiter == nil {
+ return true
+ }
+ return s.rateLimiter.allow(ip)
+}
+
+func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) {
+ resp := new(dns.Msg)
+ resp.Response = true
+ resp.ID = req.ID
+ resp.Question = req.Question
+ resp.Rcode = dns.RcodeRefused
+ io.Copy(w, resp)
+}