summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
committerradhitya <alif@radhitya.org>2026-06-21 09:48:42 +0700
commitb7359e1d45f505171356bcae3c7d5e2341ecc859 (patch)
treef91d4a4b08ce279d488a76e9b7141e69fc844ea9 /internal/server/server.go
parent2b1f613c42de3861141eb6f93c1740b6937ee183 (diff)
forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go163
1 files changed, 139 insertions, 24 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 1aa3256..a90e5ac 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,55 +2,95 @@ package server
import (
"context"
+ "fmt"
"log/slog"
+ "net"
"net/http"
+ "sync"
"time"
- "codeberg.org/miekg/dns"
- "linum/internal/resolver"
"linum/internal/blocklist"
"linum/internal/cache"
+ "linum/internal/config"
+ "linum/internal/resolver"
+
+ "codeberg.org/miekg/dns"
)
type Server struct {
- logger *slog.Logger
- resolver *resolver.Resolver
- cache *cache.Cache
- blocklist *blocklist.Blocklist
- udp *dns.Server
- tcp *dns.Server
- doh *http.Server
- baseCtx context.Context
+ resolver *resolver.Resolver
+ cache *cache.Cache
+ blocklist *blocklist.Blocklist
+ logger *slog.Logger
+ cfg config.ServerConfig
+ baseCtx context.Context
+ udp *dns.Server
+ tcp *dns.Server
+ doh *http.Server
+ admin *Admin
+ aclNets []*net.IPNet
+ rateLimiter *rateLimiter
+
+ mu sync.RWMutex
+ upUDP bool
+ upTCP bool
+ upDoH bool
+
cancel context.CancelFunc
}
-func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger,
-r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
+func (s *Server) Ready() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ wantUDP := s.cfg.ListenUDP != ""
+ wantTCP := s.cfg.ListenTCP != ""
+ wantDoH := s.cfg.ListenDOH != ""
+ return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH)
+
+}
+func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) {
baseCtx, cancel := context.WithCancel(context.Background())
+ s := &Server{
+ logger: logger,
+ resolver: r,
+ cache: c,
+ blocklist: bl,
+ cfg: cfg.Server,
+ baseCtx: baseCtx,
+ cancel: cancel,
+ aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)),
+ rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst),
+ }
- s := &Server{logger: logger, resolver: r, cache: c, blocklist: b,
- baseCtx: baseCtx, cancel: cancel}
+ for _, cidr := range cfg.ACL.Allow {
+ _, ipnet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err)
+ }
+ s.aclNets = append(s.aclNets, ipnet)
+ }
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
if udpAddr != "" {
s.udp = &dns.Server{
- Addr: udpAddr,
- Net: "udp",
- Handler: mux,
- UDPSize: 4096,
- ReadTimeout: 5 * time.Second,
- ReusePort: true,
+ Addr: udpAddr,
+ Net: "udp",
+ Handler: mux,
+ UDPSize: 4096,
+ ReadTimeout: 5 * time.Second,
+ ReusePort: true,
}
}
if tcpAddr != "" {
s.tcp = &dns.Server{
- Addr: tcpAddr,
- Net: "tcp",
- Handler: mux,
- ReadTimeout: 5 * time.Second,
+ Addr: tcpAddr,
+ Net: "tcp",
+ Handler: mux,
+ ReadTimeout: 5 * time.Second,
}
}
@@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) {
WriteTimeout: 5 * time.Second,
}
}
+
+ if cfg.Admin.Listen != "" {
+ s.admin = NewAdmin(cfg.Admin.Listen, s)
+ }
+
return s, nil
}
func (s *Server) Run(ctx context.Context) error {
+ s.mu.Lock()
+ if s.udp != nil {
+ s.upUDP = true
+ }
+ if s.tcp != nil {
+ s.upTCP = true
+ }
+ if s.doh != nil {
+ s.upDoH = true
+ }
+ s.mu.Unlock()
+
+ if s.admin != nil {
+ s.admin.Start()
+ }
+
errCh := make(chan error, 3)
if s.udp != nil {
go func() {
@@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error {
}
func (s *Server) Close() error {
+ if s.admin != nil {
+ s.admin.Close()
+ }
if s.udp != nil {
s.udp.Shutdown(context.Background())
}
@@ -121,3 +185,54 @@ func (s *Server) Close() error {
}
return nil
}
+
+type rateLimiter struct {
+ mu sync.Mutex
+ rate float64
+ burst float64
+ buckets map[string]*rateBucket
+}
+
+type rateBucket struct {
+ tokens float64
+ last time.Time
+}
+
+func newRateLimiter(qps, burst int) *rateLimiter {
+ if qps <= 0 && burst <= 0 {
+ return nil
+ }
+ if burst <= 0 {
+ burst = qps
+ }
+ return &rateLimiter{
+ rate: float64(qps),
+ burst: float64(burst),
+ buckets: make(map[string]*rateBucket),
+ }
+}
+
+func (rl *rateLimiter) allow(ip string) bool {
+ now := time.Now()
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ b, ok := rl.buckets[ip]
+ if !ok {
+ b = &rateBucket{tokens: rl.burst, last: now}
+ rl.buckets[ip] = b
+ }
+
+ elapsed := now.Sub(b.last).Seconds()
+ b.tokens += elapsed * rl.rate
+ if b.tokens > rl.burst {
+ b.tokens = rl.burst
+ }
+ b.last = now
+
+ if b.tokens < 1 {
+ return false
+ }
+ b.tokens--
+ return true
+}