From b7359e1d45f505171356bcae3c7d5e2341ecc859 Mon Sep 17 00:00:00 2001 From: radhitya Date: Sun, 21 Jun 2026 09:48:42 +0700 Subject: forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply --- internal/server/admin.go | 62 ++++++++++++++++ internal/server/doh.go | 70 +++++++++++++++--- internal/server/handler.go | 121 +++++++++++++++++++++++------- internal/server/server.go | 163 +++++++++++++++++++++++++++++++++++------ internal/server/server_test.go | 2 +- 5 files changed, 354 insertions(+), 64 deletions(-) create mode 100644 internal/server/admin.go (limited to 'internal/server') diff --git a/internal/server/admin.go b/internal/server/admin.go new file mode 100644 index 0000000..46242d4 --- /dev/null +++ b/internal/server/admin.go @@ -0,0 +1,62 @@ +package server + +import ( + "fmt" + "log/slog" + "net/http" +) + +type Admin struct { + server *http.Server + s *Server +} + +func NewAdmin(listen string, s *Server) *Admin { + if listen == "" { + return nil + } + + mux := http.NewServeMux() + a := &Admin{ + server: &http.Server{ + Addr: listen, + Handler: mux, + }, + s: s, + } + mux.HandleFunc("/health", a.healthHandler) + return a +} + +func (a *Admin) Start() { + if a == nil { + return + } + go func() { + slog.Info("admin server listening", "addr", a.server.Addr) + if err := a.server.ListenAndServe(); err != nil && + err != http.ErrServerClosed { + slog.Error("admin server failed", "error", err) + } + }() +} + +func (a *Admin) Close() error { + if a == nil { + return nil + } + return a.server.Close() +} + +func (a *Admin) healthHandler(w http.ResponseWriter, r *http.Request) { + if a.s == nil { + http.Error(w, "not ready", http.StatusServiceUnavailable) + return + } + if a.s.Ready() { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok\n") + return + } + http.Error(w, "not ready", http.StatusServiceUnavailable) +} diff --git a/internal/server/doh.go b/internal/server/doh.go index 2f9dfc0..0feb094 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -2,13 +2,27 @@ package server import ( "encoding/base64" - "codeberg.org/miekg/dns" "io" "log/slog" "net/http" + + "codeberg.org/miekg/dns" + "linum/internal/cache" ) func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { + clientIP := parseHTTPClientIP(r.RemoteAddr) + if !s.isAllowed(clientIP) { + slog.Warn("doh query denied by ACL", "client", clientIP) + http.Error(w, "refused", http.StatusForbidden) + return + } + if !s.rateLimit(clientIP.String()) { + slog.Warn("doh query rate limited", "client", clientIP) + http.Error(w, "too many requests", http.StatusTooManyRequests) + return + } + var raw []byte switch r.Method { @@ -53,23 +67,55 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } + q := msg.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class + + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { + httpWriteMsg(w, s.blockedResponse(msg)) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "blocked", true, + ) + return + } + + if s.cache != nil { + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} + if packed, ok := s.cache.Get(key); ok { + packed[0] = byte(msg.ID >> 8) + packed[1] = byte(msg.ID) + w.Header().Set("Content-Type", "application/dns-message") + w.Header().Set("Cache-Control", "no-cache, max-age=0") + w.Write(packed) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "cached", true, + ) + return + } + } + resp, _ := s.buildResponse(msg) - if err := resp.Pack(); err != nil { + httpWriteMsg(w, resp) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "rcode", dns.RcodeToString[resp.Rcode], + ) +} + +func httpWriteMsg(w http.ResponseWriter, msg *dns.Msg) { + if err := msg.Pack(); err != nil { http.Error(w, "pack response", http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/dns-message") w.Header().Set("Cache-Control", "no-cache, max-age=0") - if _, err := w.Write(resp.Data); err != nil { + if _, err := w.Write(msg.Data); err != nil { slog.Error("doh write failed", "err", err) - return } - - slog.Info("doh query served", - "qname", msg.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])], - "rcode", dns.RcodeToString[resp.Rcode], - "client", r.RemoteAddr, - ) } diff --git a/internal/server/handler.go b/internal/server/handler.go index d0aa705..4516468 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "net" "net/netip" "time" @@ -14,6 +15,18 @@ import ( ) func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + clientIP := parseClientIP(w.RemoteAddr()) + if !s.isAllowed(clientIP) { + slog.Warn("query denied by ACL", "client", clientIP) + s.writeRefused(w, req) + return + } + if !s.rateLimit(clientIP.String()) { + slog.Warn("query rate limited", "client", clientIP) + s.writeRefused(w, req) + return + } + if len(req.Question) == 0 { m := new(dns.Msg) m.Rcode = dns.RcodeFormatError @@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns io.Copy(w, m) return } + q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - resp, blocked := s.buildResponse(req) + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { + resp := s.blockedResponse(req) + resp.ID = req.ID + io.Copy(w, resp) + slog.Info("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode], + "blocked", true) + return + } + if s.cache != nil { + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} + if packed, ok := s.cache.Get(key); ok { + packed[0] = byte(req.ID >> 8) + packed[1] = byte(req.ID) + io.Copy(w, &dns.Msg{Data: packed}) + slog.Debug("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "cached", true) + return + } + } + resp, blocked := s.buildResponse(req) resp.ID = req.ID resp.Data = nil if _, err := io.Copy(w, resp); err != nil { - slog.Error("write response failed", - "err", err, - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - ) + slog.Error("write failed", "err", err, "qname", qname) return } - slog.Info("query served", - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - "rcode", dns.RcodeToString[resp.Rcode], - "blocked", blocked, - ) + slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype], + "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked) } - func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { m := new(dns.Msg) @@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { m.Question = req.Question return m, false } + q := req.Question[0] qname := q.Header().Name qtype := dns.RRToType(q) qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(qname) { - return s.blockedResponse(req), true - } - - if s.cache != nil { - key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} - if cached, ok := s.cache.Get(key); ok { - cached.ID = req.ID - return cached, false - } - } - ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", qname, "qtype", dns.TypeToString[qtype], ) @@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m } + +func parseClientIP(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + } + host, _, _ := net.SplitHostPort(addr.String()) + return net.ParseIP(host) +} + +func parseHTTPClientIP(addr string) net.IP { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return net.ParseIP(addr) + } + return net.ParseIP(host) +} + +func (s *Server) isAllowed(ip net.IP) bool { + if ip == nil { + return false + } + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + for _, n := range s.aclNets { + if n.Contains(ip) { + return true + } + } + return false +} + +func (s *Server) rateLimit(ip string) bool { + if s.rateLimiter == nil { + return true + } + return s.rateLimiter.allow(ip) +} + +func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) { + resp := new(dns.Msg) + resp.Response = true + resp.ID = req.ID + resp.Question = req.Question + resp.Rcode = dns.RcodeRefused + io.Copy(w, resp) +} diff --git a/internal/server/server.go b/internal/server/server.go index 1aa3256..a90e5ac 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,55 +2,95 @@ package server import ( "context" + "fmt" "log/slog" + "net" "net/http" + "sync" "time" - "codeberg.org/miekg/dns" - "linum/internal/resolver" "linum/internal/blocklist" "linum/internal/cache" + "linum/internal/config" + "linum/internal/resolver" + + "codeberg.org/miekg/dns" ) type Server struct { - logger *slog.Logger - resolver *resolver.Resolver - cache *cache.Cache - blocklist *blocklist.Blocklist - udp *dns.Server - tcp *dns.Server - doh *http.Server - baseCtx context.Context + resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist + logger *slog.Logger + cfg config.ServerConfig + baseCtx context.Context + udp *dns.Server + tcp *dns.Server + doh *http.Server + admin *Admin + aclNets []*net.IPNet + rateLimiter *rateLimiter + + mu sync.RWMutex + upUDP bool + upTCP bool + upDoH bool + cancel context.CancelFunc } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, -r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { +func (s *Server) Ready() bool { + s.mu.RLock() + defer s.mu.RUnlock() + wantUDP := s.cfg.ListenUDP != "" + wantTCP := s.cfg.ListenTCP != "" + wantDoH := s.cfg.ListenDOH != "" + return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH) + +} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) { baseCtx, cancel := context.WithCancel(context.Background()) + s := &Server{ + logger: logger, + resolver: r, + cache: c, + blocklist: bl, + cfg: cfg.Server, + baseCtx: baseCtx, + cancel: cancel, + aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)), + rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst), + } - s := &Server{logger: logger, resolver: r, cache: c, blocklist: b, - baseCtx: baseCtx, cancel: cancel} + for _, cidr := range cfg.ACL.Allow { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + cancel() + return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err) + } + s.aclNets = append(s.aclNets, ipnet) + } mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) if udpAddr != "" { s.udp = &dns.Server{ - Addr: udpAddr, - Net: "udp", - Handler: mux, - UDPSize: 4096, - ReadTimeout: 5 * time.Second, - ReusePort: true, + Addr: udpAddr, + Net: "udp", + Handler: mux, + UDPSize: 4096, + ReadTimeout: 5 * time.Second, + ReusePort: true, } } if tcpAddr != "" { s.tcp = &dns.Server{ - Addr: tcpAddr, - Net: "tcp", - Handler: mux, - ReadTimeout: 5 * time.Second, + Addr: tcpAddr, + Net: "tcp", + Handler: mux, + ReadTimeout: 5 * time.Second, } } @@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { WriteTimeout: 5 * time.Second, } } + + if cfg.Admin.Listen != "" { + s.admin = NewAdmin(cfg.Admin.Listen, s) + } + return s, nil } func (s *Server) Run(ctx context.Context) error { + s.mu.Lock() + if s.udp != nil { + s.upUDP = true + } + if s.tcp != nil { + s.upTCP = true + } + if s.doh != nil { + s.upDoH = true + } + s.mu.Unlock() + + if s.admin != nil { + s.admin.Start() + } + errCh := make(chan error, 3) if s.udp != nil { go func() { @@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) Close() error { + if s.admin != nil { + s.admin.Close() + } if s.udp != nil { s.udp.Shutdown(context.Background()) } @@ -121,3 +185,54 @@ func (s *Server) Close() error { } return nil } + +type rateLimiter struct { + mu sync.Mutex + rate float64 + burst float64 + buckets map[string]*rateBucket +} + +type rateBucket struct { + tokens float64 + last time.Time +} + +func newRateLimiter(qps, burst int) *rateLimiter { + if qps <= 0 && burst <= 0 { + return nil + } + if burst <= 0 { + burst = qps + } + return &rateLimiter{ + rate: float64(qps), + burst: float64(burst), + buckets: make(map[string]*rateBucket), + } +} + +func (rl *rateLimiter) allow(ip string) bool { + now := time.Now() + rl.mu.Lock() + defer rl.mu.Unlock() + + b, ok := rl.buckets[ip] + if !ok { + b = &rateBucket{tokens: rl.burst, last: now} + rl.buckets[ip] = b + } + + elapsed := now.Sub(b.last).Seconds() + b.tokens += elapsed * rl.rate + if b.tokens > rl.burst { + b.tokens = rl.burst + } + b.last = now + + if b.tokens < 1 { + return false + } + b.tokens-- + return true +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 002aba8..2340acd 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,8 +1,8 @@ package server import ( - "log/slog" "context" + "log/slog" "testing" "time" -- cgit v1.2.3