package server

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"sync"
	"time"

	"linum/internal/blocklist"
	"linum/internal/cache"
	"linum/internal/config"
	"linum/internal/resolver"

	"codeberg.org/miekg/dns"
)

// listenerShutdownTimeout bounds how long Run waits for the listeners to
// drain when it stops them.
const listenerShutdownTimeout = 5 * time.Second

// rateLimiterSweepInterval is the minimum time between scans of the rate
// limiter's bucket map for idle entries.
const rateLimiterSweepInterval = time.Minute

// Server answers DNS queries over UDP, TCP and DNS-over-HTTPS (/dns-query).
// The query handlers (handleQuery, dohHandler) live elsewhere in the package;
// the ACL networks and rate limiter held here are presumably consulted by
// them — verify against the handlers.
type Server struct {
	resolver  *resolver.Resolver
	cache     *cache.Cache
	blocklist *blocklist.Blocklist
	logger    *slog.Logger
	cfg       config.ServerConfig

	// baseCtx is cancelled (via cancel) when Close is called.
	baseCtx context.Context

	// Listeners; each is nil when the corresponding address was empty.
	udp   *dns.Server
	tcp   *dns.Server
	doh   *http.Server
	admin *Admin

	aclNets     []*net.IPNet // parsed cfg.ACL.Allow entries
	rateLimiter *rateLimiter // nil when rate limiting is disabled

	// mu guards the up* flags, which record whether each listener is serving.
	mu    sync.RWMutex
	upUDP bool
	upTCP bool
	upDoH bool

	cancel context.CancelFunc
}

// Ready reports whether every listener built by New is currently marked as
// serving. Run marks listeners up just before launching them, so a true
// result does not guarantee the sockets are already bound; a listener is
// marked down again as soon as its serve loop returns.
func (s *Server) Ready() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	// Check the servers that actually exist rather than the config fields:
	// New builds listeners from its address arguments, which may differ.
	return (s.udp == nil || s.upUDP) &&
		(s.tcp == nil || s.upTCP) &&
		(s.doh == nil || s.upDoH)
}

// New builds a Server with one listener per non-empty address: udpAddr (DNS
// over UDP), tcpAddr (DNS over TCP) and dohAddr (DNS-over-HTTPS served on
// /dns-query). The DNS and DoH listeners are not started until Run. An Admin
// is created via NewAdmin when cfg.Admin.Listen is set.
//
// Each cfg.ACL.Allow entry must be a CIDR block ("10.0.0.0/8") or a plain IP
// address, which is treated as a single-host network. An invalid entry makes
// New return an error.
func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) {
	baseCtx, cancel := context.WithCancel(context.Background())
	s := &Server{
		logger:      logger,
		resolver:    r,
		cache:       c,
		blocklist:   bl,
		cfg:         cfg.Server,
		baseCtx:     baseCtx,
		cancel:      cancel,
		aclNets:     make([]*net.IPNet, 0, len(cfg.ACL.Allow)),
		rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst),
	}

	for _, entry := range cfg.ACL.Allow {
		ipnet, err := parseACLEntry(entry)
		if err != nil {
			cancel()
			return nil, fmt.Errorf("invalid acl CIDR %q: %w", entry, err)
		}
		s.aclNets = append(s.aclNets, ipnet)
	}

	mux := dns.NewServeMux()
	mux.HandleFunc(".", s.handleQuery)

	if udpAddr != "" {
		s.udp = &dns.Server{
			Addr:        udpAddr,
			Net:         "udp",
			Handler:     mux,
			UDPSize:     4096,
			ReadTimeout: 5 * time.Second,
			ReusePort:   true,
		}
	}
	if tcpAddr != "" {
		s.tcp = &dns.Server{
			Addr:        tcpAddr,
			Net:         "tcp",
			Handler:     mux,
			ReadTimeout: 5 * time.Second,
		}
	}
	if dohAddr != "" {
		dohMux := http.NewServeMux()
		dohMux.HandleFunc("/dns-query", s.dohHandler)
		s.doh = &http.Server{
			Addr:         dohAddr,
			Handler:      dohMux,
			ReadTimeout:  5 * time.Second,
			WriteTimeout: 5 * time.Second,
		}
	}
	if cfg.Admin.Listen != "" {
		s.admin = NewAdmin(cfg.Admin.Listen, s)
	}
	return s, nil
}

// parseACLEntry parses entry as a CIDR block or, failing that, as a bare IP
// address mapped to a single-host network (/32 for IPv4, /128 for IPv6).
// When entry is neither, the net.ParseCIDR error is returned.
func parseACLEntry(entry string) (*net.IPNet, error) {
	_, ipnet, err := net.ParseCIDR(entry)
	if err == nil {
		return ipnet, nil
	}
	if ip := net.ParseIP(entry); ip != nil {
		if v4 := ip.To4(); v4 != nil {
			return &net.IPNet{IP: v4, Mask: net.CIDRMask(32, 32)}, nil
		}
		return &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}, nil
	}
	return nil, err
}

// Run starts the admin endpoint (if any) and every configured listener, then
// blocks until ctx is cancelled or any listener stops. Either way, all
// listeners are shut down (sharing a listenerShutdownTimeout deadline) before
// Run returns, so none outlives it.
//
// On cancellation Run returns ctx.Err(). Otherwise it returns the stopped
// listener's error annotated with its name and address, or nil if that
// listener stopped cleanly (for example because Close was called).
func (s *Server) Run(ctx context.Context) error {
	s.mu.Lock()
	s.upUDP = s.udp != nil
	s.upTCP = s.tcp != nil
	s.upDoH = s.doh != nil
	s.mu.Unlock()

	if s.admin != nil {
		s.admin.Start()
	}

	// One slot per listener so no serve goroutine blocks after Run returns.
	errCh := make(chan error, 3)
	if s.udp != nil {
		s.startListener("udp", s.udp.Addr, &s.upUDP, func() error { return s.udp.ListenAndServe() }, errCh)
	}
	if s.tcp != nil {
		s.startListener("tcp", s.tcp.Addr, &s.upTCP, func() error { return s.tcp.ListenAndServe() }, errCh)
	}
	if s.doh != nil {
		s.startListener("doh", s.doh.Addr, &s.upDoH, func() error { return s.doh.ListenAndServe() }, errCh)
	}

	var err error
	select {
	case <-ctx.Done():
		err = ctx.Err()
	case err = <-errCh:
		// One listener stopped; the rest are stopped below so their
		// goroutines and ports do not outlive Run.
	}
	s.stopListeners()
	return err
}

// startListener logs addr and runs listen in a new goroutine. When listen
// returns, the listener's up flag is cleared (under s.mu) and the result is
// sent on errCh: nil for a clean stop, including http.ErrServerClosed,
// otherwise the error annotated with name and addr. errCh must have room for
// one value per listener so the goroutine never blocks.
func (s *Server) startListener(name, addr string, up *bool, listen func() error, errCh chan<- error) {
	go func() {
		s.logger.Info(name+" listener active", "addr", addr)
		err := listen()

		s.mu.Lock()
		*up = false
		s.mu.Unlock()

		if errors.Is(err, http.ErrServerClosed) {
			err = nil
		}
		if err != nil {
			err = fmt.Errorf("%s listener %s: %w", name, addr, err)
		}
		errCh <- err
	}()
}

// stopListeners gracefully shuts down every configured listener under a
// single listenerShutdownTimeout deadline and marks them all down. Stopping
// the DNS listeners is best-effort: their Shutdown results are not inspected.
// A DoH shutdown failure is logged.
func (s *Server) stopListeners() {
	ctx, cancel := context.WithTimeout(context.Background(), listenerShutdownTimeout)
	defer cancel()

	if s.udp != nil {
		s.udp.Shutdown(ctx)
	}
	if s.tcp != nil {
		s.tcp.Shutdown(ctx)
	}
	if s.doh != nil {
		if err := s.doh.Shutdown(ctx); err != nil {
			s.logger.Warn("doh shutdown failed", "err", err)
		}
	}

	s.mu.Lock()
	s.upUDP, s.upTCP, s.upDoH = false, false, false
	s.mu.Unlock()
}

// Close cancels the server's base context, closes the admin endpoint, and
// stops all listeners. The DoH server is closed immediately rather than
// drained. Close returns the error from closing the DoH server, if any.
func (s *Server) Close() error {
	// Cancel first so work derived from baseCtx (presumably in-flight
	// queries) is aborted instead of holding up the shutdowns below.
	if s.cancel != nil {
		s.cancel()
	}
	if s.admin != nil {
		s.admin.Close()
	}
	if s.udp != nil {
		s.udp.Shutdown(context.Background())
	}
	if s.tcp != nil {
		s.tcp.Shutdown(context.Background())
	}

	var err error
	if s.doh != nil {
		if cerr := s.doh.Close(); cerr != nil {
			err = fmt.Errorf("closing doh listener: %w", cerr)
		}
	}

	s.mu.Lock()
	s.upUDP, s.upTCP, s.upDoH = false, false, false
	s.mu.Unlock()
	return err
}

// rateLimiter is a per-key token-bucket limiter. Each key may spend up to
// burst tokens at once, and buckets refill at rate tokens per second. A nil
// *rateLimiter allows every request.
type rateLimiter struct {
	mu        sync.Mutex
	rate      float64
	burst     float64
	buckets   map[string]*rateBucket
	lastSweep time.Time // when idle buckets were last pruned
}

// rateBucket is the token state for one key.
type rateBucket struct {
	tokens float64   // available tokens, never above the limiter's burst
	last   time.Time // when tokens was last brought up to date
}

// newRateLimiter returns a limiter that refills at qps tokens per second, with
// bucket size burst (burst defaults to qps when burst <= 0).
//
// It returns nil, meaning rate limiting is disabled, when qps <= 0. A
// non-positive refill rate would permanently lock a client out once its burst
// was spent, or, for a negative qps, drain its bucket over time.
func newRateLimiter(qps, burst int) *rateLimiter {
	if qps <= 0 {
		return nil
	}
	if burst <= 0 {
		burst = qps
	}
	return &rateLimiter{
		rate:      float64(qps),
		burst:     float64(burst),
		buckets:   make(map[string]*rateBucket),
		lastSweep: time.Now(),
	}
}

// allow reports whether a request from ip may proceed and, if so, consumes one
// token from ip's bucket. Unknown keys start with a full bucket. A nil
// limiter allows everything. Safe for concurrent use.
func (rl *rateLimiter) allow(ip string) bool {
	if rl == nil {
		return true
	}
	now := time.Now()
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Amortized pruning keeps the map from growing with every address ever seen.
	if now.Sub(rl.lastSweep) >= rateLimiterSweepInterval {
		rl.sweep(now)
	}

	b, ok := rl.buckets[ip]
	if !ok {
		b = &rateBucket{tokens: rl.burst, last: now}
		rl.buckets[ip] = b
	}
	b.tokens = min(rl.burst, b.tokens+now.Sub(b.last).Seconds()*rl.rate)
	b.last = now
	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// sweep deletes buckets that have been idle long enough to refill completely.
// A full bucket behaves exactly like the fresh one allow would create, so
// dropping it never changes a future decision. Caller must hold rl.mu.
func (rl *rateLimiter) sweep(now time.Time) {
	for key, b := range rl.buckets {
		if b.tokens+now.Sub(b.last).Seconds()*rl.rate >= rl.burst {
			delete(rl.buckets, key)
		}
	}
	rl.lastSweep = now
}