diff options
Diffstat (limited to 'internal/server/handler.go')
| -rw-r--r-- | internal/server/handler.go | 77 |
1 files changed, 51 insertions, 26 deletions
diff --git a/internal/server/handler.go b/internal/server/handler.go index 5f873d4..d0aa705 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,36 +2,43 @@ package server import ( "context" + "io" "log/slog" - "net" + "net/netip" "time" - "github.com/miekg/dns" + "codeberg.org/miekg/dns" + "codeberg.org/miekg/dns/rdata" "linum/internal/blocklist" "linum/internal/cache" ) -func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { +func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { if len(req.Question) == 0 { m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - _ = w.WriteMsg(m) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + io.Copy(w, m) return } resp, blocked := s.buildResponse(req) - if err := w.WriteMsg(resp); err != nil { + resp.ID = req.ID + resp.Data = nil + if _, err := io.Copy(w, resp); err != nil { slog.Error("write response failed", "err", err, - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], ) return } slog.Info("query served", - "qname", req.Question[0].Name, - "qtype", dns.TypeToString[req.Question[0].Qtype], + "qname", req.Question[0].Header().Name, + "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked, ) @@ -39,18 +46,26 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { - return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false + m := new(dns.Msg) + m.Rcode = dns.RcodeFormatError + m.Response = true + m.ID = req.ID + m.Question = req.Question + return m, false } q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) { + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { return s.blockedResponse(req), true } if s.cache != nil { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} if cached, ok := s.cache.Get(key); ok { - cached.Id = req.Id + cached.ID = req.ID return cached, false } } @@ -58,22 +73,25 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() - reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype) + reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", "err", err, - "qname", q.Name, - "qtype", dns.TypeToString[q.Qtype], + "qname", qname, + "qtype", dns.TypeToString[qtype], ) m := new(dns.Msg) - m.SetRcode(req, dns.RcodeServerFailure) + m.Rcode = dns.RcodeServerFailure + m.Response = true + m.ID = req.ID + m.Question = req.Question return m, false } - reply.Id = req.Id + reply.ID = req.ID if s.cache != nil && reply.Rcode != dns.RcodeServerFailure { - key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} s.cache.Set(key, reply, 0) } @@ -82,7 +100,11 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { m := new(dns.Msg) - m.SetReply(req) + m.Response = true + m.ID = req.ID + m.Opcode = req.Opcode + m.RecursionDesired = req.RecursionDesired + m.Question = req.Question m.Authoritative = true if s.blocklist.Response() == blocklist.ResponseNXDOMAIN { @@ -91,16 +113,19 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { } q := req.Question[0] - switch q.Qtype { + qname := q.Header().Name + qtype := dns.RRToType(q) + + switch qtype { case dns.TypeA: m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.IPv4(0, 0, 0, 0), + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + A: rdata.A{Addr: netip.AddrFrom4([4]byte{})}, }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60}, - AAAA: net.IPv6zero, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m |
