From 01e05e8df5f56d605dfd75456a424527e76a2955 Mon Sep 17 00:00:00 2001 From: radhitya Date: Sat, 13 Jun 2026 12:46:38 +0700 Subject: dns codec, udp server (reuseport, 4096 buffers), tcp server, doh listener post get without tls, concurrent, ends0, fuzz) --- internal/server/doh.go | 75 +++++++++++++++++++++ internal/server/handler.go | 65 ++++++++++++++++++ internal/server/server.go | 147 ++++++++++++++++++++++------------------- internal/server/server_test.go | 122 ++++++++++++++++++++++++++++++++++ 4 files changed, 340 insertions(+), 69 deletions(-) create mode 100644 internal/server/doh.go create mode 100644 internal/server/handler.go create mode 100644 internal/server/server_test.go (limited to 'internal') diff --git a/internal/server/doh.go b/internal/server/doh.go new file mode 100644 index 0000000..e9cf466 --- /dev/null +++ b/internal/server/doh.go @@ -0,0 +1,75 @@ +package server + +import ( + "encoding/base64" + "github.com/miekg/dns" + "io" + "log/slog" + "net/http" +) + +func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { + var raw []byte + + switch r.Method { + case http.MethodPost: + ct := r.Header.Get("Content-Type") + if ct != "application/dns-message" { + http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType) + return + } + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 65535)) + if err != nil { + http.Error(w, "read body", http.StatusBadRequest) + return + } + raw = body + case http.MethodGet: + param := r.URL.Query().Get("dns") + if param == "" { + http.Error(w, "missing dns param", http.StatusBadRequest) + return + } + decoded, err := base64.RawURLEncoding.DecodeString(param) + if err != nil { + http.Error(w, "invalid base64url", http.StatusBadRequest) + return + } + raw = decoded + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + msg := new(dns.Msg) + if err := msg.Unpack(raw); err != nil { + http.Error(w, "invalid dns message", http.StatusBadRequest) + return + } + + if len(msg.Question) == 0 { + http.Error(w, "no question", http.StatusBadRequest) + return + } + + resp := buildResponse(msg) + packed, err := resp.Pack() + if err != nil { + http.Error(w, "pack response", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/dns-message") + w.Header().Set("Cache-Control", "no-cache, max-age=0") + if _, err := w.Write(packed); err != nil { + slog.Error("doh write failed", "err", err) + return + } + + slog.Info("doh query served", + "qname", msg.Question[0].Name, + "qtype", dns.TypeToString[msg.Question[0].Qtype], + "rcode", dns.RcodeToString[resp.Rcode], + "client", r.RemoteAddr, + ) +} diff --git a/internal/server/handler.go b/internal/server/handler.go new file mode 100644 index 0000000..2e2f08b --- /dev/null +++ b/internal/server/handler.go @@ -0,0 +1,65 @@ +package server + +import ( + "github.com/miekg/dns" + "log/slog" +) + +func handleQuery(w dns.ResponseWriter, req *dns.Msg) { + if len(req.Question) == 0 { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeFormatError) + _ = w.WriteMsg(m) + return + } + + resp := buildResponse(req) + + if err := w.WriteMsg(resp); err != nil { + slog.Error("write response failed", + "err", err, + "qname", req.Question[0].Name, + "qtype", dns.TypeToString[req.Question[0].Qtype], + ) + return + } + slog.Info("query served", + "qname", req.Question[0].Name, + "qtype", dns.TypeToString[req.Question[0].Qtype], + "rcode", dns.RcodeToString[resp.Rcode], + "client", w.RemoteAddr().String(), + ) +} + +func buildResponse(req *dns.Msg) *dns.Msg { + if len(req.Question) == 0 { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeFormatError) + return m + } + q := req.Question[0] + resp := new(dns.Msg) + resp.SetReply(req) + resp.Authoritative = false + + if opt := req.IsEdns0(); opt != nil { + resp.SetEdns0(4096, false) + } + + if q.Name == "example.com." && q.Qtype == dns.TypeA { + resp.Answer = []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 60, + }, + A: []byte{127, 0, 0, 1}, + }, + } + } else { + resp.Rcode = dns.RcodeNameError + } + return resp +} diff --git a/internal/server/server.go b/internal/server/server.go index 01bda09..f40648e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,50 +3,95 @@ package server import ( "context" "log/slog" + "net/http" "time" "github.com/miekg/dns" ) type Server struct { - addr string logger *slog.Logger - inner *dns.Server + udp *dns.Server + tcp *dns.Server + doh *http.Server } -func New(addr string, logger *slog.Logger) (*Server, error) { +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger) (*Server, error) { mux := dns.NewServeMux() mux.HandleFunc(".", handleQuery) - inner := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - UDPSize: 4096, - ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, + s := &Server{logger: logger} + + if udpAddr != "" { + s.udp = &dns.Server{ + Addr: udpAddr, + Net: "udp", + Handler: mux, + UDPSize: 4096, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + ReusePort: true, + } } - return &Server{ - addr: addr, - logger: logger, - inner: inner, - }, nil -} + if tcpAddr != "" { + s.tcp = &dns.Server{ + Addr: tcpAddr, + Net: "tcp", + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + } + + if dohAddr != "" { + dohMux := http.NewServeMux() + dohMux.HandleFunc("/dns-query", s.dohHandler) + s.doh = &http.Server{ + Addr: dohAddr, + Handler: dohMux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + } + return s, nil +} func (s *Server) Run(ctx context.Context) error { - errCh := make(chan error, 1) - go func() { - s.logger.Info("udp listener active", "addr", s.addr) - errCh <- s.inner.ListenAndServe() - }() + errCh := make(chan error, 3) + if s.udp != nil { + go func() { + s.logger.Info("udp listener active", "addr", s.udp.Addr) + errCh <- s.udp.ListenAndServe() + }() + } + + if s.tcp != nil { + go func() { + s.logger.Info("tcp listener active", "addr", s.tcp.Addr) + errCh <- s.tcp.ListenAndServe() + }() + } + + if s.doh != nil { + go func() { + s.logger.Info("doh listener active", "addr", s.doh.Addr) + errCh <- s.doh.ListenAndServe() + }() + } select { case <-ctx.Done(): - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5 *time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := s.inner.ShutdownContext(shutdownCtx); err != nil { - s.logger.Error("graceful shutdown failed", "err", err) - return err + + if s.udp != nil { + s.udp.ShutdownContext(shutdownCtx) + } + if s.tcp != nil { + s.tcp.ShutdownContext(shutdownCtx) + } + if s.doh != nil { + s.doh.Shutdown(shutdownCtx) } return ctx.Err() case err := <-errCh: @@ -55,50 +100,14 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) Close() error { - return s.inner.Shutdown() -} - -func handleQuery(w dns.ResponseWriter, req *dns.Msg) { - if len(req.Question) == 0 { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeFormatError) - _ = w.WriteMsg(m) - return + if s.udp != nil { + s.udp.Shutdown() } - - q := req.Question[0] - resp := new(dns.Msg) - resp.SetReply(req) - resp.Authoritative = false - - if q.Name == "example.com." && q.Qtype == dns.TypeA { - resp.Answer = []dns.RR{ - &dns.A{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 60, - }, - A: []byte{127,0,0,1}, - }, - } - } else { - resp.Rcode = dns.RcodeNameError + if s.tcp != nil { + s.tcp.Shutdown() } - - if err := w.WriteMsg(resp); err != nil { - slog.Error("write response failed", - "err", err, - "qname", q.Name, - "qtype", dns.TypeToString[q.Qtype], - ) - return -} -slog.Info("query served", -"qname", q.Name, -"qtype", dns.TypeToString[q.Qtype], -"rcode", dns.RcodeToString[resp.Rcode], -"client", w.RemoteAddr().String(), -) + if s.doh != nil { + s.doh.Close() + } + return nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..eaf0190 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,122 @@ +package server + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestBuildResponse(t *testing.T) { + tests := []struct { + name string + req *dns.Msg + wantRcode int + wantAnswers int + wantEdns0 bool + }{ + { + name: "example.com A returns 127.0.0.1", + req: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + }(), + wantRcode: dns.RcodeSuccess, + wantAnswers: 1, + wantEdns0: false, + }, + { + name: "google.com A returns NXDOMAIN", + req: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("google.com.", dns.TypeA) + return m + }(), + wantRcode: dns.RcodeNameError, + wantAnswers: 0, + wantEdns0: false, + }, + { + name: "other.com A returns NXDOMAIN", + req: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("other.com.", dns.TypeA) + return m + }(), + wantRcode: dns.RcodeNameError, + wantAnswers: 0, + wantEdns0: false, + }, + { + name: "no questions returns FORMERR", + req: func() *dns.Msg { + return new(dns.Msg) + }(), + wantRcode: dns.RcodeFormatError, + wantAnswers: 0, + wantEdns0: false, + }, + { + name: "EDNS0 query preserved with 4096 buffer", + req: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetEdns0(1232, true) + return m + }(), + wantRcode: dns.RcodeSuccess, + wantAnswers: 1, + wantEdns0: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := buildResponse(tt.req) + if resp.Rcode != tt.wantRcode { + t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode) + } + if len(resp.Answer) != tt.wantAnswers { + t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers) + } + if tt.wantEdns0 { + if opt := resp.IsEdns0(); opt == nil { + t.Error("expected EDNS0 in response, got none") + } else if opt.UDPSize() != 4096 { + t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize()) + } + } + }) + } +} + +func FuzzBuildResponse(f *testing.F) { + seed := []byte{ + 0x00, 0x00, // ID + 0x01, 0x00, // flags: RD + 0x00, 0x01, // QDCOUNT: 1 + 0x00, 0x00, // ANCOUNT + 0x00, 0x00, // NSCOUNT + 0x00, 0x00, // ARCOUNT + // Question: example.com A + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, + 0x03, 0x63, 0x6f, 0x6d, + 0x00, + 0x00, 0x01, // QTYPE: A + 0x00, 0x01, // QCLASS: IN + } + f.Add(seed) + f.Fuzz(func(t *testing.T, data []byte) { + msg := new(dns.Msg) + if err := msg.Unpack(data); err != nil { + return + } + resp := buildResponse(msg) + if resp == nil { + t.Fatal("buildResponse returned nil") + } + if _, err := resp.Pack(); err != nil { + t.Errorf("pack failed: %v", err) + } + }) +} -- cgit v1.2.3