summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/doh.go14
-rw-r--r--internal/server/handler.go77
-rw-r--r--internal/server/server.go12
-rw-r--r--internal/server/server_test.go20
4 files changed, 72 insertions, 51 deletions
diff --git a/internal/server/doh.go b/internal/server/doh.go
index b46736e..2f9dfc0 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -2,7 +2,7 @@ package server
import (
"encoding/base64"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"io"
"log/slog"
"net/http"
@@ -42,7 +42,8 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
}
msg := new(dns.Msg)
- if err := msg.Unpack(raw); err != nil {
+ msg.Data = raw
+ if err := msg.Unpack(); err != nil {
http.Error(w, "invalid dns message", http.StatusBadRequest)
return
}
@@ -53,22 +54,21 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
}
resp, _ := s.buildResponse(msg)
- packed, err := resp.Pack()
- if err != nil {
+ if err := resp.Pack(); err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/dns-message")
w.Header().Set("Cache-Control", "no-cache, max-age=0")
- if _, err := w.Write(packed); err != nil {
+ if _, err := w.Write(resp.Data); err != nil {
slog.Error("doh write failed", "err", err)
return
}
slog.Info("doh query served",
- "qname", msg.Question[0].Name,
- "qtype", dns.TypeToString[msg.Question[0].Qtype],
+ "qname", msg.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])],
"rcode", dns.RcodeToString[resp.Rcode],
"client", r.RemoteAddr,
)
diff --git a/internal/server/handler.go b/internal/server/handler.go
index 5f873d4..d0aa705 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -2,36 +2,43 @@ package server
import (
"context"
+ "io"
"log/slog"
- "net"
+ "net/netip"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
+ "codeberg.org/miekg/dns/rdata"
"linum/internal/blocklist"
"linum/internal/cache"
)
-func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
+func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
if len(req.Question) == 0 {
m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeFormatError)
- _ = w.WriteMsg(m)
+ m.Rcode = dns.RcodeFormatError
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
+ io.Copy(w, m)
return
}
resp, blocked := s.buildResponse(req)
- if err := w.WriteMsg(resp); err != nil {
+ resp.ID = req.ID
+ resp.Data = nil
+ if _, err := io.Copy(w, resp); err != nil {
slog.Error("write response failed",
"err", err,
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "qname", req.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
)
return
}
slog.Info("query served",
- "qname", req.Question[0].Name,
- "qtype", dns.TypeToString[req.Question[0].Qtype],
+ "qname", req.Question[0].Header().Name,
+ "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
"rcode", dns.RcodeToString[resp.Rcode],
"blocked", blocked,
)
@@ -39,18 +46,26 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) {
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
- return new(dns.Msg).SetRcode(req, dns.RcodeFormatError), false
+ m := new(dns.Msg)
+ m.Rcode = dns.RcodeFormatError
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
+ return m, false
}
q := req.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
- if s.blocklist != nil && s.blocklist.IsBlocked(q.Name) {
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
return s.blockedResponse(req), true
}
if s.cache != nil {
- key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
if cached, ok := s.cache.Get(key); ok {
- cached.Id = req.Id
+ cached.ID = req.ID
return cached, false
}
}
@@ -58,22 +73,25 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
- reply, err := s.resolver.Lookup(ctx, q.Name, q.Qtype)
+ reply, err := s.resolver.Lookup(ctx, qname, qtype)
if err != nil {
slog.Error("resolution failed",
"err", err,
- "qname", q.Name,
- "qtype", dns.TypeToString[q.Qtype],
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
)
m := new(dns.Msg)
- m.SetRcode(req, dns.RcodeServerFailure)
+ m.Rcode = dns.RcodeServerFailure
+ m.Response = true
+ m.ID = req.ID
+ m.Question = req.Question
return m, false
}
- reply.Id = req.Id
+ reply.ID = req.ID
if s.cache != nil && reply.Rcode != dns.RcodeServerFailure {
- key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass}
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
s.cache.Set(key, reply, 0)
}
@@ -82,7 +100,11 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
m := new(dns.Msg)
- m.SetReply(req)
+ m.Response = true
+ m.ID = req.ID
+ m.Opcode = req.Opcode
+ m.RecursionDesired = req.RecursionDesired
+ m.Question = req.Question
m.Authoritative = true
if s.blocklist.Response() == blocklist.ResponseNXDOMAIN {
@@ -91,16 +113,19 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
}
q := req.Question[0]
- switch q.Qtype {
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+
+ switch qtype {
case dns.TypeA:
m.Answer = append(m.Answer, &dns.A{
- Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
- A: net.IPv4(0, 0, 0, 0),
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ A: rdata.A{Addr: netip.AddrFrom4([4]byte{})},
})
case dns.TypeAAAA:
m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 60},
- AAAA: net.IPv6zero,
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
})
}
return m
diff --git a/internal/server/server.go b/internal/server/server.go
index 8f991eb..1aa3256 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,7 +6,7 @@ import (
"net/http"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"linum/internal/resolver"
"linum/internal/blocklist"
"linum/internal/cache"
@@ -41,7 +41,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
Handler: mux,
UDPSize: 4096,
ReadTimeout: 5 * time.Second,
- WriteTimeout: 5 * time.Second,
ReusePort: true,
}
}
@@ -52,7 +51,6 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
Net: "tcp",
Handler: mux,
ReadTimeout: 5 * time.Second,
- WriteTimeout: 5 * time.Second,
}
}
@@ -97,10 +95,10 @@ func (s *Server) Run(ctx context.Context) error {
defer cancel()
if s.udp != nil {
- s.udp.ShutdownContext(shutdownCtx)
+ s.udp.Shutdown(shutdownCtx)
}
if s.tcp != nil {
- s.tcp.ShutdownContext(shutdownCtx)
+ s.tcp.Shutdown(shutdownCtx)
}
if s.doh != nil {
s.doh.Shutdown(shutdownCtx)
@@ -113,10 +111,10 @@ func (s *Server) Run(ctx context.Context) error {
func (s *Server) Close() error {
if s.udp != nil {
- s.udp.Shutdown()
+ s.udp.Shutdown(context.Background())
}
if s.tcp != nil {
- s.tcp.Shutdown()
+ s.tcp.Shutdown(context.Background())
}
if s.doh != nil {
s.doh.Close()
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 42938d1..002aba8 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -6,7 +6,7 @@ import (
"testing"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
"linum/internal/resolver"
)
@@ -27,7 +27,7 @@ func TestBuildResponse(t *testing.T) {
tests := []struct {
name string
req *dns.Msg
- wantRcode int
+ wantRcode uint16
wantAnswers int
wantEdns0 bool
}{
@@ -52,10 +52,10 @@ func TestBuildResponse(t *testing.T) {
t.Errorf("answers: got %d, want %d", len(resp.Answer), tt.wantAnswers)
}
if tt.wantEdns0 {
- if opt := resp.IsEdns0(); opt == nil {
+ if resp.UDPSize == 0 {
t.Error("expected EDNS0 in response, got none")
- } else if opt.UDPSize() != 4096 {
- t.Errorf("edns0 udp size: got %d, want 4096", opt.UDPSize())
+ } else if resp.UDPSize != 4096 {
+ t.Errorf("edns0 udp size: got %d, want 4096", resp.UDPSize)
}
}
})
@@ -64,9 +64,7 @@ func TestBuildResponse(t *testing.T) {
func TestBuildResponseWithQuery(t *testing.T) {
s := testServer(t)
- // Valid query → must not panic, rcode must be valid
- m := new(dns.Msg)
- m.SetQuestion("example.com.", dns.TypeA)
+ m := dns.NewMsg("example.com.", dns.TypeA)
resp, _ := s.buildResponse(m)
if resp == nil {
t.Fatal("buildResponse returned nil")
@@ -80,7 +78,6 @@ func FuzzBuildResponse(f *testing.F) {
baseCtx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Server{logger: slog.Default(), baseCtx: baseCtx}
- // For fuzz, use a resolver that won't make real network calls
s.resolver = resolver.New(
resolver.WithRootAddresses([]string{"127.0.0.1:1"}),
resolver.WithTimeout(10*time.Millisecond),
@@ -103,14 +100,15 @@ func FuzzBuildResponse(f *testing.F) {
f.Add(seed)
f.Fuzz(func(t *testing.T, data []byte) {
msg := new(dns.Msg)
- if err := msg.Unpack(data); err != nil {
+ msg.Data = data
+ if err := msg.Unpack(); err != nil {
return
}
resp, _ := s.buildResponse(msg)
if resp == nil {
t.Fatal("buildResponse returned nil")
}
- if _, err := resp.Pack(); err != nil {
+ if err := resp.Pack(); err != nil {
t.Errorf("pack failed: %v", err)
}
})