summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/doh.go2
-rw-r--r--internal/server/handler.go37
-rw-r--r--internal/server/server.go17
-rw-r--r--internal/server/server_test.go83
4 files changed, 65 insertions, 74 deletions
diff --git a/internal/server/doh.go b/internal/server/doh.go
index e9cf466..3f5a538 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -52,7 +52,7 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
return
}
- resp := buildResponse(msg)
+ resp := s.buildResponse(msg)
packed, err := resp.Pack()
if err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
diff --git a/internal/server/handler.go b/internal/server/handler.go
index 2e2f08b..bac1c81 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -1,11 +1,13 @@
package server
import (
+ "context"
"github.com/miekg/dns"
"log/slog"
+ "time"
)
-func handleQuery(w dns.ResponseWriter, req *dns.Msg) {
+func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
if len(req.Question) == 0 {
m := new(dns.Msg)
m.SetRcode(req, dns.RcodeFormatError)
@@ -13,7 +15,7 @@ func handleQuery(w dns.ResponseWriter, req *dns.Msg) {
return
}
- resp := buildResponse(req)
+ resp := s.buildResponse(req)
if err := w.WriteMsg(resp); err != nil {
slog.Error("write response failed",
@@ -31,7 +33,7 @@ func handleQuery(w dns.ResponseWriter, req *dns.Msg) {
)
}
-func buildResponse(req *dns.Msg) *dns.Msg {
+func (s *Server) buildResponse(req *dns.Msg) *dns.Msg {
if len(req.Question) == 0 {
m := new(dns.Msg)
m.SetRcode(req, dns.RcodeFormatError)
@@ -46,20 +48,19 @@ func buildResponse(req *dns.Msg) *dns.Msg {
resp.SetEdns0(4096, false)
}
- if q.Name == "example.com." && q.Qtype == dns.TypeA {
- resp.Answer = []dns.RR{
- &dns.A{
- Hdr: dns.RR_Header{
- Name: q.Name,
- Rrtype: dns.TypeA,
- Class: dns.ClassINET,
- Ttl: 60,
- },
- A: []byte{127, 0, 0, 1},
- },
- }
- } else {
- resp.Rcode = dns.RcodeNameError
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
+ if err != nil {
+ slog.Error("resolution failed",
+ "err", err,
+ "qname", q.Name,
+ "qtype", dns.TypeToString[q.Qtype],
+ )
+ resp.Rcode = dns.RcodeServerFailure
+ return resp
}
- return resp
+
+ return reply
}
diff --git a/internal/server/server.go b/internal/server/server.go
index f40648e..3114073 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -7,20 +7,21 @@ import (
"time"
"github.com/miekg/dns"
+ "sdns/internal/resolver"
)
type Server struct {
- logger *slog.Logger
- udp *dns.Server
- tcp *dns.Server
- doh *http.Server
+ logger *slog.Logger
+ resolver *resolver.Resolver
+ udp *dns.Server
+ tcp *dns.Server
+ doh *http.Server
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger) (*Server, error) {
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver) (*Server, error) {
+ s := &Server{logger: logger, resolver: r}
mux := dns.NewServeMux()
- mux.HandleFunc(".", handleQuery)
-
- s := &Server{logger: logger}
+ mux.HandleFunc(".", s.handleQuery)
if udpAddr != "" {
s.udp = &dns.Server{
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index eaf0190..6fc2092 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -1,12 +1,25 @@
package server
import (
+ "log/slog"
"testing"
+ "time"
"github.com/miekg/dns"
+ "sdns/internal/resolver"
)
+func testServer(t *testing.T) *Server {
+ t.Helper()
+ r := resolver.New(
+ resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
+ resolver.WithTimeout(50*time.Millisecond),
+ )
+ return &Server{logger: slog.Default(), resolver: r}
+}
+
func TestBuildResponse(t *testing.T) {
+ s := testServer(t)
tests := []struct {
name string
req *dns.Msg
@@ -15,39 +28,6 @@ func TestBuildResponse(t *testing.T) {
wantEdns0 bool
}{
{
- name: "example.com A returns 127.0.0.1",
- req: func() *dns.Msg {
- m := new(dns.Msg)
- m.SetQuestion("example.com.", dns.TypeA)
- return m
- }(),
- wantRcode: dns.RcodeSuccess,
- wantAnswers: 1,
- wantEdns0: false,
- },
- {
- name: "google.com A returns NXDOMAIN",
- req: func() *dns.Msg {
- m := new(dns.Msg)
- m.SetQuestion("google.com.", dns.TypeA)
- return m
- }(),
- wantRcode: dns.RcodeNameError,
- wantAnswers: 0,
- wantEdns0: false,
- },
- {
- name: "other.com A returns NXDOMAIN",
- req: func() *dns.Msg {
- m := new(dns.Msg)
- m.SetQuestion("other.com.", dns.TypeA)
- return m
- }(),
- wantRcode: dns.RcodeNameError,
- wantAnswers: 0,
- wantEdns0: false,
- },
- {
name: "no questions returns FORMERR",
req: func() *dns.Msg {
return new(dns.Msg)
@@ -56,23 +36,11 @@ func TestBuildResponse(t *testing.T) {
wantAnswers: 0,
wantEdns0: false,
},
- {
- name: "EDNS0 query preserved with 4096 buffer",
- req: func() *dns.Msg {
- m := new(dns.Msg)
- m.SetQuestion("example.com.", dns.TypeA)
- m.SetEdns0(1232, true)
- return m
- }(),
- wantRcode: dns.RcodeSuccess,
- wantAnswers: 1,
- wantEdns0: true,
- },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- resp := buildResponse(tt.req)
+ resp := s.buildResponse(tt.req)
if resp.Rcode != tt.wantRcode {
t.Errorf("rcode: got %d, want %d", resp.Rcode, tt.wantRcode)
}
@@ -90,7 +58,28 @@ func TestBuildResponse(t *testing.T) {
}
}
+func TestBuildResponseWithQuery(t *testing.T) {
+ s := testServer(t)
+ // Valid query → must not panic, rcode must be valid
+ m := new(dns.Msg)
+ m.SetQuestion("example.com.", dns.TypeA)
+ resp := s.buildResponse(m)
+ if resp == nil {
+ t.Fatal("buildResponse returned nil")
+ }
+ if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeServerFailure {
+ t.Errorf("expected success or server failure, got %d", resp.Rcode)
+ }
+}
+
func FuzzBuildResponse(f *testing.F) {
+ s := &Server{logger: slog.Default()}
+ // For fuzz, use a resolver that won't make real network calls
+ s.resolver = resolver.New(
+ resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
+ resolver.WithTimeout(10*time.Millisecond),
+ )
+
seed := []byte{
0x00, 0x00, // ID
0x01, 0x00, // flags: RD
@@ -111,7 +100,7 @@ func FuzzBuildResponse(f *testing.F) {
if err := msg.Unpack(data); err != nil {
return
}
- resp := buildResponse(msg)
+ resp := s.buildResponse(msg)
if resp == nil {
t.Fatal("buildResponse returned nil")
}