summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/server/doh.go75
-rw-r--r--internal/server/handler.go65
-rw-r--r--internal/server/server.go147
-rw-r--r--internal/server/server_test.go122
-rw-r--r--main.go21
5 files changed, 356 insertions, 74 deletions
diff --git a/internal/server/doh.go b/internal/server/doh.go
new file mode 100644
index 0000000..e9cf466
--- /dev/null
+++ b/internal/server/doh.go
@@ -0,0 +1,75 @@
+package server
+
+import (
+ "encoding/base64"
+ "github.com/miekg/dns"
+ "io"
+ "log/slog"
+ "net/http"
+)
+
+func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
+ var raw []byte
+
+ switch r.Method {
+ case http.MethodPost:
+ ct := r.Header.Get("Content-Type")
+ if ct != "application/dns-message" {
+ http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType)
+ return
+ }
+ body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 65535))
+ if err != nil {
+ http.Error(w, "read body", http.StatusBadRequest)
+ return
+ }
+ raw = body
+ case http.MethodGet:
+ param := r.URL.Query().Get("dns")
+ if param == "" {
+ http.Error(w, "missing dns param", http.StatusBadRequest)
+ return
+ }
+ decoded, err := base64.RawURLEncoding.DecodeString(param)
+ if err != nil {
+ http.Error(w, "invalid base64url", http.StatusBadRequest)
+ return
+ }
+ raw = decoded
+ default:
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ msg := new(dns.Msg)
+ if err := msg.Unpack(raw); err != nil {
+ http.Error(w, "invalid dns message", http.StatusBadRequest)
+ return
+ }
+
+ if len(msg.Question) == 0 {
+ http.Error(w, "no question", http.StatusBadRequest)
+ return
+ }
+
+ resp := buildResponse(msg)
+ packed, err := resp.Pack()
+ if err != nil {
+ http.Error(w, "pack response", http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/dns-message")
+ w.Header().Set("Cache-Control", "no-cache, max-age=0")
+ if _, err := w.Write(packed); err != nil {
+ slog.Error("doh write failed", "err", err)
+ return
+ }
+
+ slog.Info("doh query served",
+ "qname", msg.Question[0].Name,
+ "qtype", dns.TypeToString[msg.Question[0].Qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ "client", r.RemoteAddr,
+ )
+}
diff --git a/internal/server/handler.go b/internal/server/handler.go
new file mode 100644
index 0000000..2e2f08b
--- /dev/null
+++ b/internal/server/handler.go
@@ -0,0 +1,65 @@
+package server
+
+import (
+ "github.com/miekg/dns"
+ "log/slog"
+)
+
+func handleQuery(w dns.ResponseWriter, req *dns.Msg) {
+ if len(req.Question) == 0 {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeFormatError)
+ _ = w.WriteMsg(m)
+ return
+ }
+
+ resp := buildResponse(req)
+
+ if err := w.WriteMsg(resp); err != nil {
+ slog.Error("write response failed",
+ "err", err,
+ "qname", req.Question[0].Name,
+ "qtype", dns.TypeToString[req.Question[0].Qtype],
+ )
+ return
+ }
+ slog.Info("query served",
+ "qname", req.Question[0].Name,
+ "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ "client", w.RemoteAddr().String(),
+ )
+}
+
+func buildResponse(req *dns.Msg) *dns.Msg {
+ if len(req.Question) == 0 {
+ m := new(dns.Msg)
+ m.SetRcode(req, dns.RcodeFormatError)
+ return m
+ }
+ q := req.Question[0]
+ resp := new(dns.Msg)
+ resp.SetReply(req)
+ resp.Authoritative = false
+
+ if opt := req.IsEdns0(); opt != nil {
+ resp.SetEdns0(4096, false)
+ }
+
+ if q.Name == "example.com." && q.Qtype == dns.TypeA {
+ resp.Answer = []dns.RR{
+ &dns.A{
+ Hdr: dns.RR_Header{
+ Name: q.Name,
+ Rrtype: dns.TypeA,
+ Class: dns.ClassINET,
+ Ttl: 60,
+ },
+ A: []byte{127, 0, 0, 1},
+ },
+ }
+ } else {
+ resp.Rcode = dns.RcodeNameError
+ }
+ return resp
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 01bda09..f40648e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -3,50 +3,95 @@ package server
import (
"context"
"log/slog"
+ "net/http"
"time"
"github.com/miekg/dns"
)
type Server struct {
- addr string
logger *slog.Logger
- inner *dns.Server
+ udp *dns.Server
+ tcp *dns.Server
+ doh *http.Server
}
-func New(addr string, logger *slog.Logger) (*Server, error) {
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger) (*Server, error) {
mux := dns.NewServeMux()
mux.HandleFunc(".", handleQuery)
- inner := &dns.Server{
- Addr: addr,
- Net: "udp",
- Handler: mux,
- UDPSize: 4096,
- ReadTimeout: 5 * time.Second,
- WriteTimeout: 5 * time.Second,
+ s := &Server{logger: logger}
+
+ if udpAddr != "" {
+ s.udp = &dns.Server{
+ Addr: udpAddr,
+ Net: "udp",
+ Handler: mux,
+ UDPSize: 4096,
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 5 * time.Second,
+ ReusePort: true,
+ }
}
- return &Server{
- addr: addr,
- logger: logger,
- inner: inner,
- }, nil
-}
+ if tcpAddr != "" {
+ s.tcp = &dns.Server{
+ Addr: tcpAddr,
+ Net: "tcp",
+ Handler: mux,
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 5 * time.Second,
+ }
+ }
+
+ if dohAddr != "" {
+ dohMux := http.NewServeMux()
+ dohMux.HandleFunc("/dns-query", s.dohHandler)
+ s.doh = &http.Server{
+ Addr: dohAddr,
+ Handler: dohMux,
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 5 * time.Second,
+ }
+ }
+ return s, nil
+}
func (s *Server) Run(ctx context.Context) error {
- errCh := make(chan error, 1)
- go func() {
- s.logger.Info("udp listener active", "addr", s.addr)
- errCh <- s.inner.ListenAndServe()
- }()
+ errCh := make(chan error, 3)
+ if s.udp != nil {
+ go func() {
+ s.logger.Info("udp listener active", "addr", s.udp.Addr)
+ errCh <- s.udp.ListenAndServe()
+ }()
+ }
+
+ if s.tcp != nil {
+ go func() {
+ s.logger.Info("tcp listener active", "addr", s.tcp.Addr)
+ errCh <- s.tcp.ListenAndServe()
+ }()
+ }
+
+ if s.doh != nil {
+ go func() {
+ s.logger.Info("doh listener active", "addr", s.doh.Addr)
+ errCh <- s.doh.ListenAndServe()
+ }()
+ }
select {
case <-ctx.Done():
- shutdownCtx, cancel := context.WithTimeout(context.Background(), 5 *time.Second)
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
- if err := s.inner.ShutdownContext(shutdownCtx); err != nil {
- s.logger.Error("graceful shutdown failed", "err", err)
- return err
+
+ if s.udp != nil {
+ s.udp.ShutdownContext(shutdownCtx)
+ }
+ if s.tcp != nil {
+ s.tcp.ShutdownContext(shutdownCtx)
+ }
+ if s.doh != nil {
+ s.doh.Shutdown(shutdownCtx)
}
return ctx.Err()
case err := <-errCh:
@@ -55,50 +100,14 @@ func (s *Server) Run(ctx context.Context) error {
}
func (s *Server) Close() error {
- return s.inner.Shutdown()
-}
-
-func handleQuery(w dns.ResponseWriter, req *dns.Msg) {
- if len(req.Question) == 0 {
- m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- _ = w.WriteMsg(m)
- return
+ if s.udp != nil {
+ s.udp.Shutdown()
}
-
- q := req.Question[0]
- resp := new(dns.Msg)
- resp.SetReply(req)
- resp.Authoritative = false
-
- if q.Name == "example.com." && q.Qtype == dns.TypeA {
- resp.Answer = []dns.RR{
- &dns.A{
- Hdr: dns.RR_Header{
- Name: q.Name,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 60,
- },
- A: []byte{127,0,0,1},
- },
- }
- } else {
- resp.Rcode = dns.RcodeNameError
+ if s.tcp != nil {
+ s.tcp.Shutdown()
}
-
- if err := w.WriteMsg(resp); err != nil {
- slog.Error("write response failed",
- "err", err,
- "qname", q.Name,
- "qtype", dns.TypeToString[q.Qtype],
- )
- return
-}
-slog.Info("query served",
-"qname", q.Name,
-"qtype", dns.TypeToString[q.Qtype],
-"rcode", dns.RcodeToString[resp.Rcode],
-"client", w.RemoteAddr().String(),
-)
+ if s.doh != nil {
+ s.doh.Close()
+ }
+ return nil
}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
new file mode 100644
index 0000000..eaf0190
--- /dev/null
+++ b/internal/server/server_test.go
@@ -0,0 +1,122 @@
+package server
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+)
+
+func TestBuildResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ req *dns.Msg
+ wantRcode int
+ wantAnswers int
+ wantEdns0 bool
+ }{
+ {
+ name: "example.com A returns 127.0.0.1",
+ req: func() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("example.com.", dns.TypeA)
+ return m
+ }(),
+ wantRcode: dns.RcodeSuccess,
+ wantAnswers: 1,
+ wantEdns0: false,
+ },
+ {
+ name: "google.com A returns NXDOMAIN",
+ req: func() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("google.com.", dns.TypeA)
+ return m
+ }(),
+ wantRcode: dns.RcodeNameError,
+ wantAnswers: 0,
+ wantEdns0: false,
+ },
+ {
+ name: "other.com A returns NXDOMAIN",
+ req: func() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("other.com.", dns.TypeA)
+ return m
+ }(),
+ wantRcode: dns.RcodeNameError,
+ wantAnswers: 0,
+ wantEdns0: false,
+ },
+ {
+ name: "no questions returns FORMERR",
+ req: func() *dns.Msg {
+ return new(dns.Msg)
+ }(),
+ wantRcode: dns.RcodeFormatError,
+ wantAnswers: 0,
+ wantEdns0: false,
+ },
+ {
+ name: "EDNS0 query preserved with 4096 buffer",
+ req: func() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion("example.com.", dns.TypeA)
+ m.SetEdns0(1232, true)
+ return m
+ }(),
+ wantRcode: dns.RcodeSuccess,
+ wantAnswers: 1,
+ wantEdns0: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ resp := buildResponse(tt.req)
+ if resp.Rcode != tt.wantRcode {
+ t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode)
+ }
+ if len(resp.Answer) != tt.wantAnswers {
+ t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers)
+ }
+ if tt.wantEdns0 {
+ if opt := resp.IsEdns0(); opt == nil {
+ t.Error("expected EDNS0 in response, got none")
+ } else if opt.UDPSize() != 4096 {
+ t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize())
+ }
+ }
+ })
+ }
+}
+
+func FuzzBuildResponse(f *testing.F) {
+ seed := []byte{
+ 0x00, 0x00, // ID
+ 0x01, 0x00, // flags: RD
+ 0x00, 0x01, // QDCOUNT: 1
+ 0x00, 0x00, // ANCOUNT
+ 0x00, 0x00, // NSCOUNT
+ 0x00, 0x00, // ARCOUNT
+ // Question: example.com A
+ 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
+ 0x03, 0x63, 0x6f, 0x6d,
+ 0x00,
+ 0x00, 0x01, // QTYPE: A
+ 0x00, 0x01, // QCLASS: IN
+ }
+ f.Add(seed)
+ f.Fuzz(func(t *testing.T, data []byte) {
+ msg := new(dns.Msg)
+ if err := msg.Unpack(data); err != nil {
+ return
+ }
+ resp := buildResponse(msg)
+ if resp == nil {
+ t.Fatal("buildResponse returned nil")
+ }
+ if _, err := resp.Pack(); err != nil {
+ t.Errorf("pack failed: %v", err)
+ }
+ })
+}
diff --git a/main.go b/main.go
index 5d9e526..de6cfb4 100644
--- a/main.go
+++ b/main.go
@@ -14,23 +14,34 @@ func main() {
logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
+ slog.SetDefault(logger)
- addr := os.Getenv("SDNS_LISTEN")
- if addr == "" {
- addr = ":5353"
+ udp := os.Getenv("SDNS_LISTEN_UDP")
+ if udp == "" {
+ udp = ":5353"
+ }
+
+ tcp := os.Getenv("SDNS_LISTEN_TCP")
+ if tcp == "" {
+ tcp = ":5353"
+ }
+
+ doh := os.Getenv("SDNS_LISTEN_DOH")
+ if doh == "" {
+ doh = ":8443"
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
- srv, err := server.New(addr, logger)
+ srv, err := server.New(udp, tcp, doh, logger)
if err != nil {
logger.Error("create server failed", "err", err)
os.Exit(1)
}
defer srv.Close()
- logger.Info("sdns starting", "addr", addr)
+ logger.Info("sdns starting", "udp", udp, "tcp", tcp, "doh", doh)
if err := srv.Run(ctx); err != nil && err != context.Canceled {
logger.Error("server stopped with error", "err", err)