diff options
Diffstat (limited to 'internal/server/handler.go')
| -rw-r--r-- | internal/server/handler.go | 121 |
1 files changed, 94 insertions, 27 deletions
diff --git a/internal/server/handler.go b/internal/server/handler.go index d0aa705..4516468 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "net" "net/netip" "time" @@ -14,6 +15,18 @@ import ( ) func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + clientIP := parseClientIP(w.RemoteAddr()) + if !s.isAllowed(clientIP) { + slog.Warn("query denied by ACL", "client", clientIP) + s.writeRefused(w, req) + return + } + if !s.rateLimit(clientIP.String()) { + slog.Warn("query rate limited", "client", clientIP) + s.writeRefused(w, req) + return + } + if len(req.Question) == 0 { m := new(dns.Msg) m.Rcode = dns.RcodeFormatError @@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns io.Copy(w, m) return } + q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - resp, blocked := s.buildResponse(req) + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { + resp := s.blockedResponse(req) + resp.ID = req.ID + io.Copy(w, resp) + slog.Info("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode], + "blocked", true) + return + } + if s.cache != nil { + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} + if packed, ok := s.cache.Get(key); ok { + packed[0] = byte(req.ID >> 8) + packed[1] = byte(req.ID) + io.Copy(w, &dns.Msg{Data: packed}) + slog.Debug("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "cached", true) + return + } + } + resp, blocked := s.buildResponse(req) resp.ID = req.ID resp.Data = nil if _, err := io.Copy(w, resp); err != nil { - slog.Error("write response failed", - "err", err, - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - ) + slog.Error("write failed", "err", err, "qname", qname) return } - slog.Info("query served", - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - "rcode", dns.RcodeToString[resp.Rcode], - "blocked", blocked, - ) + slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype], + "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked) } - func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { m := new(dns.Msg) @@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { m.Question = req.Question return m, false } + q := req.Question[0] qname := q.Header().Name qtype := dns.RRToType(q) qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(qname) { - return s.blockedResponse(req), true - } - - if s.cache != nil { - key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} - if cached, ok := s.cache.Get(key); ok { - cached.ID = req.ID - return cached, false - } - } - ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", qname, "qtype", dns.TypeToString[qtype], ) @@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m } + +func parseClientIP(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + } + host, _, _ := net.SplitHostPort(addr.String()) + return net.ParseIP(host) +} + +func parseHTTPClientIP(addr string) net.IP { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return net.ParseIP(addr) + } + return net.ParseIP(host) +} + +func (s *Server) isAllowed(ip net.IP) bool { + if ip == nil { + return false + } + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + for _, n := range s.aclNets { + if n.Contains(ip) { + return true + } + } + return false +} + +func (s *Server) rateLimit(ip string) bool { + if s.rateLimiter == nil { + return true + } + return s.rateLimiter.allow(ip) +} + +func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) { + resp := new(dns.Msg) + resp.Response = true + resp.ID = req.ID + resp.Question = req.Question + resp.Rcode = dns.RcodeRefused + io.Copy(w, resp) +} |
