From 01e05e8df5f56d605dfd75456a424527e76a2955 Mon Sep 17 00:00:00 2001 From: radhitya Date: Sat, 13 Jun 2026 12:46:38 +0700 Subject: dns codec, udp server (reuseport, 4096 buffers), tcp server, doh listener post get without tls, concurrent, ends0, fuzz) --- internal/server/server.go | 147 ++++++++++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 69 deletions(-) (limited to 'internal/server/server.go') diff --git a/internal/server/server.go b/internal/server/server.go index 01bda09..f40648e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,50 +3,95 @@ package server import ( "context" "log/slog" + "net/http" "time" "github.com/miekg/dns" ) type Server struct { - addr string logger *slog.Logger - inner *dns.Server + udp *dns.Server + tcp *dns.Server + doh *http.Server } -func New(addr string, logger *slog.Logger) (*Server, error) { +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger) (*Server, error) { mux := dns.NewServeMux() mux.HandleFunc(".", handleQuery) - inner := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - UDPSize: 4096, - ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, + s := &Server{logger: logger} + + if udpAddr != "" { + s.udp = &dns.Server{ + Addr: udpAddr, + Net: "udp", + Handler: mux, + UDPSize: 4096, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + ReusePort: true, + } } - return &Server{ - addr: addr, - logger: logger, - inner: inner, - }, nil -} + if tcpAddr != "" { + s.tcp = &dns.Server{ + Addr: tcpAddr, + Net: "tcp", + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + } + + if dohAddr != "" { + dohMux := http.NewServeMux() + dohMux.HandleFunc("/dns-query", s.dohHandler) + s.doh = &http.Server{ + Addr: dohAddr, + Handler: dohMux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + } + return s, nil +} func (s *Server) Run(ctx context.Context) error { - errCh := make(chan error, 1) - go func() { - s.logger.Info("udp listener active", "addr", s.addr) - errCh <- s.inner.ListenAndServe() - }() + errCh := make(chan error, 3) + if s.udp != nil { + go func() { + s.logger.Info("udp listener active", "addr", s.udp.Addr) + errCh <- s.udp.ListenAndServe() + }() + } + + if s.tcp != nil { + go func() { + s.logger.Info("tcp listener active", "addr", s.tcp.Addr) + errCh <- s.tcp.ListenAndServe() + }() + } + + if s.doh != nil { + go func() { + s.logger.Info("doh listener active", "addr", s.doh.Addr) + errCh <- s.doh.ListenAndServe() + }() + } select { case <-ctx.Done(): - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5 *time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := s.inner.ShutdownContext(shutdownCtx); err != nil { - s.logger.Error("graceful shutdown failed", "err", err) - return err + + if s.udp != nil { + s.udp.ShutdownContext(shutdownCtx) + } + if s.tcp != nil { + s.tcp.ShutdownContext(shutdownCtx) + } + if s.doh != nil { + s.doh.Shutdown(shutdownCtx) } return ctx.Err() case err := <-errCh: @@ -55,50 +100,14 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) Close() error { - return s.inner.Shutdown() -} - -func handleQuery(w dns.ResponseWriter, req *dns.Msg) { - if len(req.Question) == 0 { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - _ = w.WriteMsg(m) - return + if s.udp != nil { + s.udp.Shutdown() } - - q := req.Question[0] - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = false - - if q.Name == "example.com." && q.Qtype == dns.TypeA { - resp.Answer = []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 60, - }, - A: []byte{127,0,0,1}, - }, - } - } else { - resp.Rcode = dns.RcodeNameError + if s.tcp != nil { + s.tcp.Shutdown() } - - if err := w.WriteMsg(resp); err != nil { - slog.Error("write response failed", - "err", err, - "qname", q.Name, - "qtype", dns.TypeToString[q.Qtype], - ) - return -} -slog.Info("query served", -"qname", q.Name, -"qtype", dns.TypeToString[q.Qtype], -"rcode", dns.RcodeToString[resp.Rcode], -"client", w.RemoteAddr().String(), -) + if s.doh != nil { + s.doh.Close() + } + return nil } -- cgit v1.2.3