package server

import (
	"context"
	"errors"
	"io"
	"log/slog"
	"net"
	"net/netip"
	"runtime/debug"
	"time"

	"codeberg.org/miekg/dns"
	"codeberg.org/miekg/dns/rdata"

	"linum/internal/blocklist"
	"linum/internal/cache"
	"linum/internal/resolver"
)

// handleQuery answers a single DNS request written to w.
//
// Processing order: ACL check, rate limit, question validation, blocklist,
// cache lookup, and finally resolution through buildResponse. A panic anywhere
// in that chain is recovered, logged with a stack trace, and answered with
// SERVFAIL.
//
// The ctx parameter is currently unused; resolution in buildResponse runs
// under s.baseCtx.
func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		slog.Error("panic in query handler", "recover", r, "stack", string(debug.Stack()))
		m := new(dns.Msg)
		m.Response = true
		m.Rcode = dns.RcodeServerFailure
		if req != nil {
			m.ID = req.ID
			m.Question = req.Question
		}
		if _, err := io.Copy(w, m); err != nil {
			slog.Error("write failed", "err", err)
		}
	}()

	clientIP := parseClientIP(w.RemoteAddr())
	if !s.isAllowed(clientIP) {
		slog.Warn("query denied by ACL", "client", clientIP)
		s.writeRefused(w, req)
		return
	}
	if !s.rateLimit(clientIP.String()) {
		slog.Warn("query rate limited", "client", clientIP)
		s.writeRefused(w, req)
		return
	}

	if len(req.Question) == 0 {
		m := new(dns.Msg)
		m.Rcode = dns.RcodeFormatError
		m.Response = true
		m.ID = req.ID
		m.Question = req.Question
		if _, err := io.Copy(w, m); err != nil {
			slog.Error("write failed", "err", err)
		}
		return
	}

	q := req.Question[0]
	qname := q.Header().Name
	qtype := dns.RRToType(q)
	qclass := q.Header().Class

	if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
		resp := s.blockedResponse(req)
		resp.ID = req.ID
		if _, err := io.Copy(w, resp); err != nil {
			slog.Error("write failed", "err", err, "qname", qname)
			return
		}
		slog.Info("query served",
			"qname", qname,
			"qtype", dns.TypeToString[qtype],
			"rcode", dns.RcodeToString[resp.Rcode],
			"blocked", true)
		return
	}

	if s.cache != nil {
		// A DNS header is 12 bytes; anything shorter cannot be a valid
		// message and is treated as a cache miss instead of being indexed.
		const headerLen = 12

		key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
		if packed, ok := s.cache.Get(key); ok && len(packed) >= headerLen {
			// Patch the transaction ID into a private copy. Writing into
			// the slice returned by the cache would modify the stored entry
			// and race with other queries served from the same key.
			wire := append([]byte(nil), packed...)
			wire[0] = byte(req.ID >> 8)
			wire[1] = byte(req.ID)
			if _, err := io.Copy(w, &dns.Msg{Data: wire}); err != nil {
				slog.Error("write failed", "err", err, "qname", qname)
				return
			}
			slog.Debug("query served",
				"qname", qname,
				"qtype", dns.TypeToString[qtype],
				"cached", true)
			return
		}
	}

	resp, blocked := s.buildResponse(req)
	resp.ID = req.ID
	// Clear any previously packed bytes after the ID has been rewritten.
	// NOTE(review): presumably this forces the message to be re-encoded on
	// write — confirm against the dns package.
	resp.Data = nil
	if _, err := io.Copy(w, resp); err != nil {
		slog.Error("write failed", "err", err, "qname", qname)
		return
	}
	slog.Info("query served",
		"qname", qname,
		"qtype", dns.TypeToString[qtype],
		"rcode", dns.RcodeToString[resp.Rcode],
		"blocked", blocked)
}

// buildResponse resolves the first question of req and returns the reply.
//
// A request without questions yields FORMERR. A resolver error yields
// NOTIMP when it wraps resolver.ErrUnsupportedType and SERVFAIL otherwise.
// Successful replies whose rcode is not SERVFAIL are stored in the cache
// (when one is configured). The lookup runs under s.baseCtx with a
// 10-second timeout.
//
// The boolean result is always false in this implementation.
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
	if len(req.Question) == 0 {
		m := new(dns.Msg)
		m.Rcode = dns.RcodeFormatError
		m.Response = true
		m.ID = req.ID
		m.Question = req.Question
		return m, false
	}

	q := req.Question[0]
	qname := q.Header().Name
	qtype := dns.RRToType(q)
	qclass := q.Header().Class

	ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
	defer cancel()

	reply, err := s.resolver.Lookup(ctx, qname, qtype)
	if err != nil {
		rcode := dns.RcodeServerFailure
		if errors.Is(err, resolver.ErrUnsupportedType) {
			rcode = dns.RcodeNotImplemented
		}
		slog.Error("resolution failed",
			"err", err,
			"qname", qname,
			"qtype", dns.TypeToString[qtype],
		)
		m := new(dns.Msg)
		m.Rcode = uint16(rcode)
		m.Response = true
		m.ID = req.ID
		m.Question = req.Question
		return m, false
	}

	reply.ID = req.ID
	if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
		key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
		s.cache.Set(key, reply, 0)
	}
	return reply, false
}

// blockedResponse builds the answer for a blocklisted name.
//
// When the blocklist is configured for NXDOMAIN the reply carries only that
// rcode. Otherwise A queries are answered with 0.0.0.0 and AAAA queries with
// ::, both with a TTL of 60; every other query type gets an empty answer
// section. The reply is marked authoritative and echoes the request's ID,
// opcode, RD flag and question section.
func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
	m := new(dns.Msg)
	m.Response = true
	m.ID = req.ID
	m.Opcode = req.Opcode
	m.RecursionDesired = req.RecursionDesired
	m.Question = req.Question
	m.Authoritative = true

	if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
		m.Rcode = dns.RcodeNameError
		return m
	}
	// Without a question there is no name to synthesize an address for.
	if len(req.Question) == 0 {
		return m
	}

	q := req.Question[0]
	qname := q.Header().Name
	switch dns.RRToType(q) {
	case dns.TypeA:
		m.Answer = append(m.Answer, &dns.A{
			Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
			A:   rdata.A{Addr: netip.AddrFrom4([4]byte{})},
		})
	case dns.TypeAAAA:
		m.Answer = append(m.Answer, &dns.AAAA{
			Hdr:  dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
			AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
		})
	}
	return m
}

// parseClientIP extracts the IP address from a transport address.
//
// UDP and TCP addresses are read directly; any other address type is parsed
// from its string form. It returns nil when addr is nil or no IP can be
// parsed.
func parseClientIP(addr net.Addr) net.IP {
	switch a := addr.(type) {
	case nil:
		// Calling String on a nil interface would panic.
		return nil
	case *net.UDPAddr:
		return a.IP
	case *net.TCPAddr:
		return a.IP
	}
	// Share the string parsing (and its no-port fallback) with the HTTP path.
	return parseHTTPClientIP(addr.String())
}

// parseHTTPClientIP parses the IP out of a "host:port" string. When addr
// cannot be split into host and port, the whole string is parsed as a bare
// IP. It returns nil when no IP can be parsed.
func parseHTTPClientIP(addr string) net.IP {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return net.ParseIP(addr)
	}
	return net.ParseIP(host)
}

// isAllowed reports whether ip falls inside any network in s.aclNets.
// A nil ip is never allowed. An address with an IPv4 form is converted to
// its 4-byte representation before matching.
func (s *Server) isAllowed(ip net.IP) bool {
	if ip == nil {
		return false
	}
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	}
	for _, n := range s.aclNets {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}

// rateLimit reports whether a query from ip may proceed. With no rate
// limiter configured every query is allowed.
func (s *Server) rateLimit(ip string) bool {
	if s.rateLimiter == nil {
		return true
	}
	return s.rateLimiter.allow(ip)
}

// writeRefused writes a REFUSED response to w, echoing the ID and question
// section of req when req is non-nil.
func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) {
	resp := new(dns.Msg)
	resp.Response = true
	resp.Rcode = dns.RcodeRefused
	if req != nil {
		resp.ID = req.ID
		resp.Question = req.Question
	}
	if _, err := io.Copy(w, resp); err != nil {
		slog.Error("write failed", "err", err)
	}
}