diff options
Diffstat (limited to 'internal/server/server.go')
| -rw-r--r-- | internal/server/server.go | 163 |
1 files changed, 139 insertions, 24 deletions
diff --git a/internal/server/server.go b/internal/server/server.go index 1aa3256..a90e5ac 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,55 +2,95 @@ package server import ( "context" + "fmt" "log/slog" + "net" "net/http" + "sync" "time" - "codeberg.org/miekg/dns" - "linum/internal/resolver" "linum/internal/blocklist" "linum/internal/cache" + "linum/internal/config" + "linum/internal/resolver" + + "codeberg.org/miekg/dns" ) type Server struct { - logger *slog.Logger - resolver *resolver.Resolver - cache *cache.Cache - blocklist *blocklist.Blocklist - udp *dns.Server - tcp *dns.Server - doh *http.Server - baseCtx context.Context + resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist + logger *slog.Logger + cfg config.ServerConfig + baseCtx context.Context + udp *dns.Server + tcp *dns.Server + doh *http.Server + admin *Admin + aclNets []*net.IPNet + rateLimiter *rateLimiter + + mu sync.RWMutex + upUDP bool + upTCP bool + upDoH bool + cancel context.CancelFunc } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, -r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { +func (s *Server) Ready() bool { + s.mu.RLock() + defer s.mu.RUnlock() + wantUDP := s.cfg.ListenUDP != "" + wantTCP := s.cfg.ListenTCP != "" + wantDoH := s.cfg.ListenDOH != "" + return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH) + +} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) { baseCtx, cancel := context.WithCancel(context.Background()) + s := &Server{ + logger: logger, + resolver: r, + cache: c, + blocklist: bl, + cfg: cfg.Server, + baseCtx: baseCtx, + cancel: cancel, + aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)), + rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst), + } - s := &Server{logger: logger, resolver: r, cache: c, blocklist: b, - baseCtx: baseCtx, cancel: cancel} + for _, cidr := range cfg.ACL.Allow { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + cancel() + return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err) + } + s.aclNets = append(s.aclNets, ipnet) + } mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) if udpAddr != "" { s.udp = &dns.Server{ - Addr: udpAddr, - Net: "udp", - Handler: mux, - UDPSize: 4096, - ReadTimeout: 5 * time.Second, - ReusePort: true, + Addr: udpAddr, + Net: "udp", + Handler: mux, + UDPSize: 4096, + ReadTimeout: 5 * time.Second, + ReusePort: true, } } if tcpAddr != "" { s.tcp = &dns.Server{ - Addr: tcpAddr, - Net: "tcp", - Handler: mux, - ReadTimeout: 5 * time.Second, + Addr: tcpAddr, + Net: "tcp", + Handler: mux, + ReadTimeout: 5 * time.Second, } } @@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { WriteTimeout: 5 * time.Second, } } + + if cfg.Admin.Listen != "" { + s.admin = NewAdmin(cfg.Admin.Listen, s) + } + return s, nil } func (s *Server) Run(ctx context.Context) error { + s.mu.Lock() + if s.udp != nil { + s.upUDP = true + } + if s.tcp != nil { + s.upTCP = true + } + if s.doh != nil { + s.upDoH = true + } + s.mu.Unlock() + + if s.admin != nil { + s.admin.Start() + } + errCh := make(chan error, 3) if s.udp != nil { go func() { @@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) Close() error { + if s.admin != nil { + s.admin.Close() + } if s.udp != nil { s.udp.Shutdown(context.Background()) } @@ -121,3 +185,54 @@ func (s *Server) Close() error { } return nil } + +type rateLimiter struct { + mu sync.Mutex + rate float64 + burst float64 + buckets map[string]*rateBucket +} + +type rateBucket struct { + tokens float64 + last time.Time +} + +func newRateLimiter(qps, burst int) *rateLimiter { + if qps <= 0 && burst <= 0 { + return nil + } + if burst <= 0 { + burst = qps + } + return &rateLimiter{ + rate: float64(qps), + burst: float64(burst), + buckets: make(map[string]*rateBucket), + } +} + +func (rl *rateLimiter) allow(ip string) bool { + now := time.Now() + rl.mu.Lock() + defer rl.mu.Unlock() + + b, ok := rl.buckets[ip] + if !ok { + b = &rateBucket{tokens: rl.burst, last: now} + rl.buckets[ip] = b + } + + elapsed := now.Sub(b.last).Seconds() + b.tokens += elapsed * rl.rate + if b.tokens > rl.burst { + b.tokens = rl.burst + } + b.last = now + + if b.tokens < 1 { + return false + } + b.tokens-- + return true +} |
