summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/admin.go62
-rw-r--r--internal/server/doh.go70
-rw-r--r--internal/server/handler.go121
-rw-r--r--internal/server/server.go163
-rw-r--r--internal/server/server_test.go2
5 files changed, 354 insertions, 64 deletions
diff --git a/internal/server/admin.go b/internal/server/admin.go
new file mode 100644
index 0000000..46242d4
--- /dev/null
+++ b/internal/server/admin.go
@@ -0,0 +1,62 @@
+package server
+
+import (
+ "fmt"
+ "log/slog"
+ "net/http"
+)
+
+type Admin struct {
+ server *http.Server
+ s *Server
+}
+
+func NewAdmin(listen string, s *Server) *Admin {
+ if listen == "" {
+ return nil
+ }
+
+ mux := http.NewServeMux()
+ a := &Admin{
+ server: &http.Server{
+ Addr: listen,
+ Handler: mux,
+ },
+ s: s,
+ }
+ mux.HandleFunc("/health", a.healthHandler)
+ return a
+}
+
+func (a *Admin) Start() {
+ if a == nil {
+ return
+ }
+ go func() {
+ slog.Info("admin server listening", "addr", a.server.Addr)
+ if err := a.server.ListenAndServe(); err != nil &&
+ err != http.ErrServerClosed {
+ slog.Error("admin server failed", "error", err)
+ }
+ }()
+}
+
+func (a *Admin) Close() error {
+ if a == nil {
+ return nil
+ }
+ return a.server.Close()
+}
+
+func (a *Admin) healthHandler(w http.ResponseWriter, r *http.Request) {
+ if a.s == nil {
+ http.Error(w, "not ready", http.StatusServiceUnavailable)
+ return
+ }
+ if a.s.Ready() {
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprint(w, "ok\n")
+ return
+ }
+ http.Error(w, "not ready", http.StatusServiceUnavailable)
+}
diff --git a/internal/server/doh.go b/internal/server/doh.go
index 2f9dfc0..0feb094 100644
--- a/internal/server/doh.go
+++ b/internal/server/doh.go
@@ -2,13 +2,27 @@ package server
import (
"encoding/base64"
- "codeberg.org/miekg/dns"
"io"
"log/slog"
"net/http"
+
+ "codeberg.org/miekg/dns"
+ "linum/internal/cache"
)
func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
+ clientIP := parseHTTPClientIP(r.RemoteAddr)
+ if !s.isAllowed(clientIP) {
+ slog.Warn("doh query denied by ACL", "client", clientIP)
+ http.Error(w, "refused", http.StatusForbidden)
+ return
+ }
+ if !s.rateLimit(clientIP.String()) {
+ slog.Warn("doh query rate limited", "client", clientIP)
+ http.Error(w, "too many requests", http.StatusTooManyRequests)
+ return
+ }
+
var raw []byte
switch r.Method {
@@ -53,23 +67,55 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
return
}
+ q := msg.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
+
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
+ httpWriteMsg(w, s.blockedResponse(msg))
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "blocked", true,
+ )
+ return
+ }
+
+ if s.cache != nil {
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
+ if packed, ok := s.cache.Get(key); ok {
+ packed[0] = byte(msg.ID >> 8)
+ packed[1] = byte(msg.ID)
+ w.Header().Set("Content-Type", "application/dns-message")
+ w.Header().Set("Cache-Control", "no-cache, max-age=0")
+ w.Write(packed)
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "cached", true,
+ )
+ return
+ }
+ }
+
resp, _ := s.buildResponse(msg)
- if err := resp.Pack(); err != nil {
+ httpWriteMsg(w, resp)
+ slog.Debug("doh query served",
+ "qname", qname,
+ "qtype", dns.TypeToString[qtype],
+ "rcode", dns.RcodeToString[resp.Rcode],
+ )
+}
+
+func httpWriteMsg(w http.ResponseWriter, msg *dns.Msg) {
+ if err := msg.Pack(); err != nil {
http.Error(w, "pack response", http.StatusInternalServerError)
return
}
-
w.Header().Set("Content-Type", "application/dns-message")
w.Header().Set("Cache-Control", "no-cache, max-age=0")
- if _, err := w.Write(resp.Data); err != nil {
+ if _, err := w.Write(msg.Data); err != nil {
slog.Error("doh write failed", "err", err)
- return
}
-
- slog.Info("doh query served",
- "qname", msg.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])],
- "rcode", dns.RcodeToString[resp.Rcode],
- "client", r.RemoteAddr,
- )
}
diff --git a/internal/server/handler.go b/internal/server/handler.go
index d0aa705..4516468 100644
--- a/internal/server/handler.go
+++ b/internal/server/handler.go
@@ -4,6 +4,7 @@ import (
"context"
"io"
"log/slog"
+ "net"
"net/netip"
"time"
@@ -14,6 +15,18 @@ import (
)
func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) {
+ clientIP := parseClientIP(w.RemoteAddr())
+ if !s.isAllowed(clientIP) {
+ slog.Warn("query denied by ACL", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+ if !s.rateLimit(clientIP.String()) {
+ slog.Warn("query rate limited", "client", clientIP)
+ s.writeRefused(w, req)
+ return
+ }
+
if len(req.Question) == 0 {
m := new(dns.Msg)
m.Rcode = dns.RcodeFormatError
@@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns
io.Copy(w, m)
return
}
+ q := req.Question[0]
+ qname := q.Header().Name
+ qtype := dns.RRToType(q)
+ qclass := q.Header().Class
- resp, blocked := s.buildResponse(req)
+ if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
+ resp := s.blockedResponse(req)
+ resp.ID = req.ID
+ io.Copy(w, resp)
+ slog.Info("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode],
+ "blocked", true)
+ return
+ }
+ if s.cache != nil {
+ key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
+ if packed, ok := s.cache.Get(key); ok {
+ packed[0] = byte(req.ID >> 8)
+ packed[1] = byte(req.ID)
+ io.Copy(w, &dns.Msg{Data: packed})
+ slog.Debug("query served", "qname", qname, "qtype",
+ dns.TypeToString[qtype], "cached", true)
+ return
+ }
+ }
+ resp, blocked := s.buildResponse(req)
resp.ID = req.ID
resp.Data = nil
if _, err := io.Copy(w, resp); err != nil {
- slog.Error("write response failed",
- "err", err,
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- )
+ slog.Error("write failed", "err", err, "qname", qname)
return
}
- slog.Info("query served",
- "qname", req.Question[0].Header().Name,
- "qtype", dns.TypeToString[dns.RRToType(req.Question[0])],
- "rcode", dns.RcodeToString[resp.Rcode],
- "blocked", blocked,
- )
+ slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype],
+ "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked)
}
-
func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
if len(req.Question) == 0 {
m := new(dns.Msg)
@@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) {
m.Question = req.Question
return m, false
}
+
q := req.Question[0]
qname := q.Header().Name
qtype := dns.RRToType(q)
qclass := q.Header().Class
- if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
- return s.blockedResponse(req), true
- }
-
- if s.cache != nil {
- key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
- if cached, ok := s.cache.Get(key); ok {
- cached.ID = req.ID
- return cached, false
- }
- }
-
ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second)
defer cancel()
reply, err := s.resolver.Lookup(ctx, qname, qtype)
if err != nil {
slog.Error("resolution failed",
- "err", err,
+ "err", err,
"qname", qname,
"qtype", dns.TypeToString[qtype],
)
@@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg {
})
case dns.TypeAAAA:
m.Answer = append(m.Answer, &dns.AAAA{
- Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
+ Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60},
AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()},
})
}
return m
}
+
+func parseClientIP(addr net.Addr) net.IP {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.IP
+ case *net.TCPAddr:
+ return a.IP
+ }
+ host, _, _ := net.SplitHostPort(addr.String())
+ return net.ParseIP(host)
+}
+
+func parseHTTPClientIP(addr string) net.IP {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return net.ParseIP(addr)
+ }
+ return net.ParseIP(host)
+}
+
+func (s *Server) isAllowed(ip net.IP) bool {
+ if ip == nil {
+ return false
+ }
+ if ip4 := ip.To4(); ip4 != nil {
+ ip = ip4
+ }
+ for _, n := range s.aclNets {
+ if n.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *Server) rateLimit(ip string) bool {
+ if s.rateLimiter == nil {
+ return true
+ }
+ return s.rateLimiter.allow(ip)
+}
+
+func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) {
+ resp := new(dns.Msg)
+ resp.Response = true
+ resp.ID = req.ID
+ resp.Question = req.Question
+ resp.Rcode = dns.RcodeRefused
+ io.Copy(w, resp)
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 1aa3256..a90e5ac 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,55 +2,95 @@ package server
import (
"context"
+ "fmt"
"log/slog"
+ "net"
"net/http"
+ "sync"
"time"
- "codeberg.org/miekg/dns"
- "linum/internal/resolver"
"linum/internal/blocklist"
"linum/internal/cache"
+ "linum/internal/config"
+ "linum/internal/resolver"
+
+ "codeberg.org/miekg/dns"
)
type Server struct {
- logger *slog.Logger
- resolver *resolver.Resolver
- cache *cache.Cache
- blocklist *blocklist.Blocklist
- udp *dns.Server
- tcp *dns.Server
- doh *http.Server
- baseCtx context.Context
+ resolver *resolver.Resolver
+ cache *cache.Cache
+ blocklist *blocklist.Blocklist
+ logger *slog.Logger
+ cfg config.ServerConfig
+ baseCtx context.Context
+ udp *dns.Server
+ tcp *dns.Server
+ doh *http.Server
+ admin *Admin
+ aclNets []*net.IPNet
+ rateLimiter *rateLimiter
+
+ mu sync.RWMutex
+ upUDP bool
+ upTCP bool
+ upDoH bool
+
cancel context.CancelFunc
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
-r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
+func (s *Server) Ready() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ wantUDP := s.cfg.ListenUDP != ""
+ wantTCP := s.cfg.ListenTCP != ""
+ wantDoH := s.cfg.ListenDOH != ""
+ return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH)
+
+}
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) {
baseCtx, cancel := context.WithCancel(context.Background())
+ s := &Server{
+ logger: logger,
+ resolver: r,
+ cache: c,
+ blocklist: bl,
+ cfg: cfg.Server,
+ baseCtx: baseCtx,
+ cancel: cancel,
+ aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)),
+ rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst),
+ }
- s := &Server{logger: logger, resolver: r, cache: c, blocklist: b,
- baseCtx: baseCtx, cancel: cancel}
+ for _, cidr := range cfg.ACL.Allow {
+ _, ipnet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err)
+ }
+ s.aclNets = append(s.aclNets, ipnet)
+ }
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
if udpAddr != "" {
s.udp = &dns.Server{
- Addr: udpAddr,
- Net: "udp",
- Handler: mux,
- UDPSize: 4096,
- ReadTimeout: 5 * time.Second,
- ReusePort: true,
+ Addr: udpAddr,
+ Net: "udp",
+ Handler: mux,
+ UDPSize: 4096,
+ ReadTimeout: 5 * time.Second,
+ ReusePort: true,
}
}
if tcpAddr != "" {
s.tcp = &dns.Server{
- Addr: tcpAddr,
- Net: "tcp",
- Handler: mux,
- ReadTimeout: 5 * time.Second,
+ Addr: tcpAddr,
+ Net: "tcp",
+ Handler: mux,
+ ReadTimeout: 5 * time.Second,
}
}
@@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
WriteTimeout: 5 * time.Second,
}
}
+
+ if cfg.Admin.Listen != "" {
+ s.admin = NewAdmin(cfg.Admin.Listen, s)
+ }
+
return s, nil
}
func (s *Server) Run(ctx context.Context) error {
+ s.mu.Lock()
+ if s.udp != nil {
+ s.upUDP = true
+ }
+ if s.tcp != nil {
+ s.upTCP = true
+ }
+ if s.doh != nil {
+ s.upDoH = true
+ }
+ s.mu.Unlock()
+
+ if s.admin != nil {
+ s.admin.Start()
+ }
+
errCh := make(chan error, 3)
if s.udp != nil {
go func() {
@@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error {
}
func (s *Server) Close() error {
+ if s.admin != nil {
+ s.admin.Close()
+ }
if s.udp != nil {
s.udp.Shutdown(context.Background())
}
@@ -121,3 +185,54 @@ func (s *Server) Close() error {
}
return nil
}
+
+type rateLimiter struct {
+ mu sync.Mutex
+ rate float64
+ burst float64
+ buckets map[string]*rateBucket
+}
+
+type rateBucket struct {
+ tokens float64
+ last time.Time
+}
+
+func newRateLimiter(qps, burst int) *rateLimiter {
+ if qps <= 0 && burst <= 0 {
+ return nil
+ }
+ if burst <= 0 {
+ burst = qps
+ }
+ return &rateLimiter{
+ rate: float64(qps),
+ burst: float64(burst),
+ buckets: make(map[string]*rateBucket),
+ }
+}
+
+func (rl *rateLimiter) allow(ip string) bool {
+ now := time.Now()
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ b, ok := rl.buckets[ip]
+ if !ok {
+ b = &rateBucket{tokens: rl.burst, last: now}
+ rl.buckets[ip] = b
+ }
+
+ elapsed := now.Sub(b.last).Seconds()
+ b.tokens += elapsed * rl.rate
+ if b.tokens > rl.burst {
+ b.tokens = rl.burst
+ }
+ b.last = now
+
+ if b.tokens < 1 {
+ return false
+ }
+ b.tokens--
+ return true
+}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 002aba8..2340acd 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -1,8 +1,8 @@
package server
import (
- "log/slog"
"context"
+ "log/slog"
"testing"
"time"