From b7359e1d45f505171356bcae3c7d5e2341ecc859 Mon Sep 17 00:00:00 2001 From: radhitya Date: Sun, 21 Jun 2026 09:48:42 +0700 Subject: forward mode, cache opt, ACL, rate limit, admin/health, systemd, fix UDP reply --- internal/blocklist/blocklist.go | 16 ++-- internal/blocklist/blocklist_test.go | 2 +- internal/cache/cache.go | 159 +++++++--------------------------- internal/cache/cache_test.go | 20 +++-- internal/config/config.go | 63 +++++++++++--- internal/resolver/resolver.go | 14 +-- internal/resolver/resolver_test.go | 2 +- internal/resolver/root.go | 2 +- internal/server/admin.go | 62 +++++++++++++ internal/server/doh.go | 70 ++++++++++++--- internal/server/handler.go | 121 ++++++++++++++++++++------ internal/server/server.go | 163 +++++++++++++++++++++++++++++------ internal/server/server_test.go | 2 +- 13 files changed, 466 insertions(+), 230 deletions(-) create mode 100644 internal/server/admin.go (limited to 'internal') diff --git a/internal/blocklist/blocklist.go b/internal/blocklist/blocklist.go index f6e0592..8925299 100644 --- a/internal/blocklist/blocklist.go +++ b/internal/blocklist/blocklist.go @@ -2,9 +2,9 @@ package blocklist import ( "bufio" + "fmt" "net/http" "os" - "fmt" "strings" "sync" "sync/atomic" @@ -18,19 +18,19 @@ const ( ) type Blocklist struct { - mu sync.RWMutex - blocked *trie + mu sync.RWMutex + blocked *trie exceptions *trie - response ResponseAction + response ResponseAction TotalRules int32 - Hits int64 + Hits int64 } func New(action ResponseAction) *Blocklist { return &Blocklist{ - blocked: newTrie(), + blocked: newTrie(), exceptions: newTrie(), - response: action, + response: action, } } @@ -128,8 +128,6 @@ func (b *Blocklist) LoadURL(url string) error { return b.load(bufio.NewScanner(resp.Body)) } - - func splitDomain(domain string) []string { domain = strings.TrimSuffix(domain, ".") return strings.Split(domain, ".") diff --git a/internal/blocklist/blocklist_test.go b/internal/blocklist/blocklist_test.go index b6d4e83..874744c 100644 --- a/internal/blocklist/blocklist_test.go +++ b/internal/blocklist/blocklist_test.go @@ -7,7 +7,7 @@ import ( ) func TestBlockZeroIP(t *testing.T) { -dir := t.TempDir() + dir := t.TempDir() p := filepath.Join(dir, "block.txt") os.WriteFile(p, []byte("||example.com^\n0.0.0.0 doubleclick.net\n"), 0644) diff --git a/internal/cache/cache.go b/internal/cache/cache.go index bb35b8e..b0a0959 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -12,18 +12,18 @@ import ( ) type Key struct { - Name string + Name string Qtype uint16 Class uint16 } type entry struct { - msg *dns.Msg + packed []byte storedAt time.Time - ttl time.Duration + ttl time.Duration } -func (e *entry) expired() bool { +func (e *entry) expired() bool { return time.Since(e.storedAt) >= e.ttl } @@ -36,20 +36,20 @@ func (e *entry) remaining() time.Duration { } type Cache struct { - mu sync.RWMutex + mu sync.RWMutex entries map[Key]*entry maxSize int - db *sql.DB - dbCh chan dbWrite - wg sync.WaitGroup - hits int64 - misses int64 + db *sql.DB + dbCh chan dbWrite + wg sync.WaitGroup + hits int64 + misses int64 evicted int64 - stopCh chan struct{} + stopCh chan struct{} } type dbWrite struct { key Key - e *entry + e *entry } func NewCache(maxSize int, dbPath string) (*Cache, error) { @@ -59,7 +59,7 @@ func NewCache(maxSize int, dbPath string) (*Cache, error) { c := &Cache{ entries: make(map[Key]*entry), maxSize: maxSize, - stopCh: make(chan struct{}), + stopCh: make(chan struct{}), } if dbPath != "" { @@ -108,7 +108,7 @@ func (c *Cache) Stop() { } } -func (c *Cache) Get(key Key) (*dns.Msg, bool) { +func (c *Cache) Get(key Key) ([]byte, bool) { c.mu.RLock() e, ok := c.entries[key] if !ok || e.expired() { @@ -119,10 +119,11 @@ func (c *Cache) Get(key Key) (*dns.Msg, bool) { c.mu.RUnlock() atomic.AddInt64(&c.hits, 1) - msg := deepCopyMsg(e.msg) - remaining := e.remaining() - adjustTTL(msg, e.ttl, remaining) - return msg, true + cp := make([]byte, len(e.packed)) + copy(cp, e.packed) + cp[0] = 0 + cp[1] = 0 + return cp, true } func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { @@ -133,10 +134,16 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { ttl = 60 * time.Second } + if err := msg.Pack(); err != nil { + return + } + cp := make([]byte, len(msg.Data)) + copy(cp, msg.Data) + e := &entry{ - msg: deepCopyMsg(msg), + packed: cp, storedAt: time.Now(), - ttl: ttl, + ttl: ttl, } c.mu.Lock() @@ -156,8 +163,8 @@ func (c *Cache) Set(key Key, msg *dns.Msg, ttl time.Duration) { func (c *Cache) Stats() (int64, int64, int64) { return atomic.LoadInt64(&c.hits), - atomic.LoadInt64(&c.misses), - atomic.LoadInt64(&c.evicted) + atomic.LoadInt64(&c.misses), + atomic.LoadInt64(&c.evicted) } func (c *Cache) Len() int { @@ -228,15 +235,9 @@ func (c *Cache) writeToDB(key Key, e *entry) { if c.db == nil { return } - err := e.msg.Pack() - if err != nil { - return - } - data := e.msg.Data - _, err = c.db.Exec( + _, err := c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) - VALUES (?, ?, ?, ?, ?, ?)`, - key.Name, key.Qtype, key.Class, data, + VALUES (?,?,?,?,?,?)`, key.Name, key.Qtype, key.Class, e.packed, e.storedAt.UnixNano(), int64(e.ttl), ) if err != nil { @@ -269,14 +270,8 @@ func (c *Cache) loadFromDB() { continue } - msg := new(dns.Msg) - msg.Data = data - if err := msg.Unpack(); err != nil { - continue - } - e := &entry{ - msg: msg, + packed: data, storedAt: time.Unix(0, storedAtNano), ttl: time.Duration(ttlNs), } @@ -289,7 +284,6 @@ func (c *Cache) loadFromDB() { } } - func computeTTL(msg *dns.Msg) time.Duration { if msg.Rcode == dns.RcodeNameError && len(msg.Answer) == 0 { for _, rr := range msg.Ns { @@ -332,92 +326,3 @@ func computeTTL(msg *dns.Msg) time.Duration { } return 0 } - -func adjustTTL(msg *dns.Msg, originalTTL, remaining time.Duration) { - ratio := float64(remaining) / float64(originalTTL) - for _, rr := range msg.Answer { - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } - for _, rr := range msg.Ns { - if dns.RRToType(rr) == dns.TypeOPT { - continue - } - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } - for _, rr := range msg.Extra { - if dns.RRToType(rr) == dns.TypeOPT { - continue - } - rr.Header().TTL = applyRatio(rr.Header().TTL, ratio) - } -} - -func applyRatio(ttl uint32, ratio float64) uint32 { - return uint32(float64(ttl) * ratio) -} - -func deepCopyMsg(msg *dns.Msg) *dns.Msg { - cp := new(dns.Msg) - cp.Response = msg.Response - cp.ID = msg.ID - cp.Opcode = msg.Opcode - cp.Authoritative = msg.Authoritative - cp.Truncated = msg.Truncated - cp.RecursionDesired = msg.RecursionDesired - cp.RecursionAvailable = msg.RecursionAvailable - cp.Rcode = msg.Rcode - cp.UDPSize = msg.UDPSize - cp.Question = copyRRSlice(msg.Question) - cp.Answer = copyRRSlice(msg.Answer) - cp.Ns = copyRRSlice(msg.Ns) - cp.Extra = copyRRSlice(msg.Extra) - return cp -} - -func copyRRSlice(src []dns.RR) []dns.RR { - if src == nil { - return nil - } - dst := make([]dns.RR, len(src)) - for i, rr := range src { - dst[i] = copyRR(rr) - } - return dst -} - -func copyRR(rr dns.RR) dns.RR { - if rr == nil { - return nil - } - switch v := rr.(type) { - case *dns.A: - cp := *v - return &cp - case *dns.AAAA: - cp := *v - return &cp - case *dns.NS: - cp := *v - return &cp - case *dns.CNAME: - cp := *v - return &cp - case *dns.SOA: - cp := *v - return &cp - case *dns.MX: - cp := *v - return &cp - case *dns.TXT: - cp := *v - return &cp - case *dns.SRV: - cp := *v - return &cp - case *dns.PTR: - cp := *v - return &cp - default: - return rr - } -} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 33e1d30..aa0aeb0 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -26,14 +26,18 @@ func TestSetGet(t *testing.T) { key := Key{Name: "example.com.", Qtype: dns.TypeA, Class: dns.ClassINET} c.Set(key, msg, 300*time.Second) - got, ok := c.Get(key) + packed, ok := c.Get(key) if !ok { t.Fatal("expected cache hit") } - if len(got.Answer) != 1 { - t.Fatalf("expected 1 answer, got %d", len(got.Answer)) + msg.Data = packed + if err := msg.Unpack(); err != nil { + t.Fatal(err) + } + if len(msg.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(msg.Answer)) } - a, _ := got.Answer[0].(*dns.A) + a, _ := msg.Answer[0].(*dns.A) if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } @@ -130,11 +134,15 @@ func TestSQLitePersistence(t *testing.T) { } defer c2.Stop() - got, ok := c2.Get(key) + packed, ok := c2.Get(key) if !ok { t.Fatal("expected cache hit from SQLite load") } - a, _ := got.Answer[0].(*dns.A) + msg.Data = packed + if err := msg.Unpack(); err != nil { + t.Fatal(err) + } + a, _ := msg.Answer[0].(*dns.A) if a.A.Addr != netip.AddrFrom4([4]byte{1, 2, 3, 4}) { t.Errorf("IP = %s, want 1.2.3.4", a.A.Addr) } diff --git a/internal/config/config.go b/internal/config/config.go index 1fa8069..b2c88ee 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,11 +8,13 @@ import ( ) type Config struct { - Server ServerConfig `toml:"server"` - Cache CacheConfig `toml:"cache"` - Resolver ResolverConfig `toml:"resolver"` + Server ServerConfig `toml:"server"` + Cache CacheConfig `toml:"cache"` + Resolver ResolverConfig `toml:"resolver"` Blocklist BlocklistConfig `toml:"blocklist"` - Log LogConfig `toml:"log"` + Admin AdminConfig `toml:"admin"` + ACL ACLConfig `toml:"acl"` + Log LogConfig `toml:"log"` } type ServerConfig struct { @@ -27,10 +29,10 @@ type CacheConfig struct { } type ResolverConfig struct { - Mode string `toml:"mode"` - Timeout string `toml:"timeout"` - MaxDelegations int `toml:"max_delegations"` - Forwarders []string `toml:"forwarders"` + Mode string `toml:"mode"` + Timeout string `toml:"timeout"` + MaxDelegations int `toml:"max_delegations"` + Forwarders []string `toml:"forwarders"` } type BlocklistConfig struct { @@ -44,13 +46,23 @@ type LogConfig struct { } type CLIFlags struct { - Config string - LogLevel string + Config string + LogLevel string ListenUDP string ListenTCP string ListenDOH string } +type ACLConfig struct { + Allow []string `toml:"allow"` + RateLimitQPS int `toml:"rate_limit_qps"` + RateLimitBurst int `toml:"rate_limit_burst"` +} + +type AdminConfig struct { + Listen string `toml:"listen"` +} + func ParseFlags() CLIFlags { var f CLIFlags flag.StringVar(&f.Config, "config", "linum.toml", "path to config file") @@ -62,7 +74,7 @@ func ParseFlags() CLIFlags { return f } -func Default() Config{ +func Default() Config { return Config{ Server: ServerConfig{ ListenUDP: ":5353", @@ -73,8 +85,8 @@ func Default() Config{ MaxEntries: 100000, }, Resolver: ResolverConfig{ - Mode: "recursive", - Timeout: "2s", + Mode: "recursive", + Timeout: "2s", MaxDelegations: 30, }, Blocklist: BlocklistConfig{ @@ -83,6 +95,14 @@ func Default() Config{ Log: LogConfig{ Level: "info", }, + Admin: AdminConfig{ + Listen: "127.0.0.1:8080", + }, + ACL: ACLConfig{ + Allow: []string{}, + RateLimitQPS: 50, + RateLimitBurst: 10, + }, } } @@ -132,6 +152,18 @@ func Merge(dst, src Config) Config { if src.Log.Level != "" { dst.Log.Level = src.Log.Level } + if src.ACL.Allow != nil { + dst.ACL.Allow = src.ACL.Allow + } + if src.ACL.RateLimitQPS != 0 { + dst.ACL.RateLimitQPS = src.ACL.RateLimitQPS + } + if src.ACL.RateLimitBurst != 0 { + dst.ACL.RateLimitBurst = src.ACL.RateLimitBurst + } + if src.Admin.Listen != "" { + dst.Admin.Listen = src.Admin.Listen + } return dst } @@ -157,10 +189,13 @@ func (c Config) Validate() error { default: return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response) } + switch c.ACL.Allow { + default: + } switch c.Resolver.Mode { case "recursive", "forward", "": - // nothing happened lol + // nothing happened lol default: return fmt.Errorf("invalid resolver mode %q (recursive or forward)", c.Resolver.Mode) } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 1e4f1da..c7a0694 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -20,8 +20,8 @@ type Resolver struct { maxDelegations int timeout time.Duration retries int - forwarders []string - client *dns.Client + forwarders []string + client *dns.Client } type Option func(*Resolver) @@ -77,7 +77,7 @@ func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dn reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true) if err != nil { return nil, fmt.Errorf("forward %s %s: %w", - qname, dns.TypeToString[qtype],err) + qname, dns.TypeToString[qtype], err) } return reply, nil } @@ -94,10 +94,10 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn return nil, ErrNoServers } reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, - false) + false) if err != nil { return nil, fmt.Errorf("resolve %s %s: %w", - qname, dns.TypeToString[qtype], err) + qname, dns.TypeToString[qtype], err) } switch { case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0: @@ -248,7 +248,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err } func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string, -qname string, qtype uint16, rd bool) (*dns.Msg, error) { + qname string, qtype uint16, rd bool) (*dns.Msg, error) { msg := dns.NewMsg(qname, qtype) msg.UDPSize = 4096 @@ -256,7 +256,7 @@ qname string, qtype uint16, rd bool) (*dns.Msg, error) { type result struct { reply *dns.Msg - err error + err error } ch := make(chan result, len(servers)) diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index 54f727e..daa8a98 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -116,7 +116,7 @@ func TestNextServersWithGlue(t *testing.T) { A: rdata.A{Addr: netip.MustParseAddr("192.0.2.1")}, }) msg.Extra = append(msg.Extra, &dns.AAAA{ - Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300}, + Hdr: dns.Header{Name: "ns1.example.com.", Class: dns.ClassINET, TTL: 300}, AAAA: rdata.AAAA{Addr: netip.MustParseAddr("2001:db8::1")}, }) diff --git a/internal/resolver/root.go b/internal/resolver/root.go index 0557dac..7c5244a 100644 --- a/internal/resolver/root.go +++ b/internal/resolver/root.go @@ -1,8 +1,8 @@ package resolver import ( - _ "embed" "codeberg.org/miekg/dns" + _ "embed" "strings" ) diff --git a/internal/server/admin.go b/internal/server/admin.go new file mode 100644 index 0000000..46242d4 --- /dev/null +++ b/internal/server/admin.go @@ -0,0 +1,62 @@ +package server + +import ( + "fmt" + "log/slog" + "net/http" +) + +type Admin struct { + server *http.Server + s *Server +} + +func NewAdmin(listen string, s *Server) *Admin { + if listen == "" { + return nil + } + + mux := http.NewServeMux() + a := &Admin{ + server: &http.Server{ + Addr: listen, + Handler: mux, + }, + s: s, + } + mux.HandleFunc("/health", a.healthHandler) + return a +} + +func (a *Admin) Start() { + if a == nil { + return + } + go func() { + slog.Info("admin server listening", "addr", a.server.Addr) + if err := a.server.ListenAndServe(); err != nil && + err != http.ErrServerClosed { + slog.Error("admin server failed", "error", err) + } + }() +} + +func (a *Admin) Close() error { + if a == nil { + return nil + } + return a.server.Close() +} + +func (a *Admin) healthHandler(w http.ResponseWriter, r *http.Request) { + if a.s == nil { + http.Error(w, "not ready", http.StatusServiceUnavailable) + return + } + if a.s.Ready() { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok\n") + return + } + http.Error(w, "not ready", http.StatusServiceUnavailable) +} diff --git a/internal/server/doh.go b/internal/server/doh.go index 2f9dfc0..0feb094 100644 --- a/internal/server/doh.go +++ b/internal/server/doh.go @@ -2,13 +2,27 @@ package server import ( "encoding/base64" - "codeberg.org/miekg/dns" "io" "log/slog" "net/http" + + "codeberg.org/miekg/dns" + "linum/internal/cache" ) func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { + clientIP := parseHTTPClientIP(r.RemoteAddr) + if !s.isAllowed(clientIP) { + slog.Warn("doh query denied by ACL", "client", clientIP) + http.Error(w, "refused", http.StatusForbidden) + return + } + if !s.rateLimit(clientIP.String()) { + slog.Warn("doh query rate limited", "client", clientIP) + http.Error(w, "too many requests", http.StatusTooManyRequests) + return + } + var raw []byte switch r.Method { @@ -53,23 +67,55 @@ func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) { return } + q := msg.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class + + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { + httpWriteMsg(w, s.blockedResponse(msg)) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "blocked", true, + ) + return + } + + if s.cache != nil { + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} + if packed, ok := s.cache.Get(key); ok { + packed[0] = byte(msg.ID >> 8) + packed[1] = byte(msg.ID) + w.Header().Set("Content-Type", "application/dns-message") + w.Header().Set("Cache-Control", "no-cache, max-age=0") + w.Write(packed) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "cached", true, + ) + return + } + } + resp, _ := s.buildResponse(msg) - if err := resp.Pack(); err != nil { + httpWriteMsg(w, resp) + slog.Debug("doh query served", + "qname", qname, + "qtype", dns.TypeToString[qtype], + "rcode", dns.RcodeToString[resp.Rcode], + ) +} + +func httpWriteMsg(w http.ResponseWriter, msg *dns.Msg) { + if err := msg.Pack(); err != nil { http.Error(w, "pack response", http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/dns-message") w.Header().Set("Cache-Control", "no-cache, max-age=0") - if _, err := w.Write(resp.Data); err != nil { + if _, err := w.Write(msg.Data); err != nil { slog.Error("doh write failed", "err", err) - return } - - slog.Info("doh query served", - "qname", msg.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(msg.Question[0])], - "rcode", dns.RcodeToString[resp.Rcode], - "client", r.RemoteAddr, - ) } diff --git a/internal/server/handler.go b/internal/server/handler.go index d0aa705..4516468 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "net" "net/netip" "time" @@ -14,6 +15,18 @@ import ( ) func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) { + clientIP := parseClientIP(w.RemoteAddr()) + if !s.isAllowed(clientIP) { + slog.Warn("query denied by ACL", "client", clientIP) + s.writeRefused(w, req) + return + } + if !s.rateLimit(clientIP.String()) { + slog.Warn("query rate limited", "client", clientIP) + s.writeRefused(w, req) + return + } + if len(req.Question) == 0 { m := new(dns.Msg) m.Rcode = dns.RcodeFormatError @@ -23,27 +36,42 @@ func (s *Server) handleQuery(ctx context.Context, w dns.ResponseWriter, req *dns io.Copy(w, m) return } + q := req.Question[0] + qname := q.Header().Name + qtype := dns.RRToType(q) + qclass := q.Header().Class - resp, blocked := s.buildResponse(req) + if s.blocklist != nil && s.blocklist.IsBlocked(qname) { + resp := s.blockedResponse(req) + resp.ID = req.ID + io.Copy(w, resp) + slog.Info("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "rcode", dns.RcodeToString[resp.Rcode], + "blocked", true) + return + } + if s.cache != nil { + key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} + if packed, ok := s.cache.Get(key); ok { + packed[0] = byte(req.ID >> 8) + packed[1] = byte(req.ID) + io.Copy(w, &dns.Msg{Data: packed}) + slog.Debug("query served", "qname", qname, "qtype", + dns.TypeToString[qtype], "cached", true) + return + } + } + resp, blocked := s.buildResponse(req) resp.ID = req.ID resp.Data = nil if _, err := io.Copy(w, resp); err != nil { - slog.Error("write response failed", - "err", err, - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - ) + slog.Error("write failed", "err", err, "qname", qname) return } - slog.Info("query served", - "qname", req.Question[0].Header().Name, - "qtype", dns.TypeToString[dns.RRToType(req.Question[0])], - "rcode", dns.RcodeToString[resp.Rcode], - "blocked", blocked, - ) + slog.Info("query served", "qname", qname, "qtype", dns.TypeToString[qtype], + "rcode", dns.RcodeToString[resp.Rcode], "blocked", blocked) } - func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if len(req.Question) == 0 { m := new(dns.Msg) @@ -53,30 +81,19 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { m.Question = req.Question return m, false } + q := req.Question[0] qname := q.Header().Name qtype := dns.RRToType(q) qclass := q.Header().Class - if s.blocklist != nil && s.blocklist.IsBlocked(qname) { - return s.blockedResponse(req), true - } - - if s.cache != nil { - key := cache.Key{Name: qname, Qtype: qtype, Class: qclass} - if cached, ok := s.cache.Get(key); ok { - cached.ID = req.ID - return cached, false - } - } - ctx, cancel := context.WithTimeout(s.baseCtx, 10*time.Second) defer cancel() reply, err := s.resolver.Lookup(ctx, qname, qtype) if err != nil { slog.Error("resolution failed", - "err", err, + "err", err, "qname", qname, "qtype", dns.TypeToString[qtype], ) @@ -124,9 +141,59 @@ func (s *Server) blockedResponse(req *dns.Msg) *dns.Msg { }) case dns.TypeAAAA: m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, + Hdr: dns.Header{Name: qname, Class: dns.ClassINET, TTL: 60}, AAAA: rdata.AAAA{Addr: netip.IPv6Unspecified()}, }) } return m } + +func parseClientIP(addr net.Addr) net.IP { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP + case *net.TCPAddr: + return a.IP + } + host, _, _ := net.SplitHostPort(addr.String()) + return net.ParseIP(host) +} + +func parseHTTPClientIP(addr string) net.IP { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return net.ParseIP(addr) + } + return net.ParseIP(host) +} + +func (s *Server) isAllowed(ip net.IP) bool { + if ip == nil { + return false + } + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + for _, n := range s.aclNets { + if n.Contains(ip) { + return true + } + } + return false +} + +func (s *Server) rateLimit(ip string) bool { + if s.rateLimiter == nil { + return true + } + return s.rateLimiter.allow(ip) +} + +func (s *Server) writeRefused(w dns.ResponseWriter, req *dns.Msg) { + resp := new(dns.Msg) + resp.Response = true + resp.ID = req.ID + resp.Question = req.Question + resp.Rcode = dns.RcodeRefused + io.Copy(w, resp) +} diff --git a/internal/server/server.go b/internal/server/server.go index 1aa3256..a90e5ac 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,55 +2,95 @@ package server import ( "context" + "fmt" "log/slog" + "net" "net/http" + "sync" "time" - "codeberg.org/miekg/dns" - "linum/internal/resolver" "linum/internal/blocklist" "linum/internal/cache" + "linum/internal/config" + "linum/internal/resolver" + + "codeberg.org/miekg/dns" ) type Server struct { - logger *slog.Logger - resolver *resolver.Resolver - cache *cache.Cache - blocklist *blocklist.Blocklist - udp *dns.Server - tcp *dns.Server - doh *http.Server - baseCtx context.Context + resolver *resolver.Resolver + cache *cache.Cache + blocklist *blocklist.Blocklist + logger *slog.Logger + cfg config.ServerConfig + baseCtx context.Context + udp *dns.Server + tcp *dns.Server + doh *http.Server + admin *Admin + aclNets []*net.IPNet + rateLimiter *rateLimiter + + mu sync.RWMutex + upUDP bool + upTCP bool + upDoH bool + cancel context.CancelFunc } -func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, -r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { +func (s *Server) Ready() bool { + s.mu.RLock() + defer s.mu.RUnlock() + wantUDP := s.cfg.ListenUDP != "" + wantTCP := s.cfg.ListenTCP != "" + wantDoH := s.cfg.ListenDOH != "" + return (!wantUDP || s.upUDP) && (!wantTCP || s.upTCP) && (!wantDoH || s.upDoH) + +} +func New(udpAddr, tcpAddr, dohAddr string, logger *slog.Logger, r *resolver.Resolver, c *cache.Cache, bl *blocklist.Blocklist, cfg config.Config) (*Server, error) { baseCtx, cancel := context.WithCancel(context.Background()) + s := &Server{ + logger: logger, + resolver: r, + cache: c, + blocklist: bl, + cfg: cfg.Server, + baseCtx: baseCtx, + cancel: cancel, + aclNets: make([]*net.IPNet, 0, len(cfg.ACL.Allow)), + rateLimiter: newRateLimiter(cfg.ACL.RateLimitQPS, cfg.ACL.RateLimitBurst), + } - s := &Server{logger: logger, resolver: r, cache: c, blocklist: b, - baseCtx: baseCtx, cancel: cancel} + for _, cidr := range cfg.ACL.Allow { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + cancel() + return nil, fmt.Errorf("invalid acl CIDR %q: %w", cidr, err) + } + s.aclNets = append(s.aclNets, ipnet) + } mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) if udpAddr != "" { s.udp = &dns.Server{ - Addr: udpAddr, - Net: "udp", - Handler: mux, - UDPSize: 4096, - ReadTimeout: 5 * time.Second, - ReusePort: true, + Addr: udpAddr, + Net: "udp", + Handler: mux, + UDPSize: 4096, + ReadTimeout: 5 * time.Second, + ReusePort: true, } } if tcpAddr != "" { s.tcp = &dns.Server{ - Addr: tcpAddr, - Net: "tcp", - Handler: mux, - ReadTimeout: 5 * time.Second, + Addr: tcpAddr, + Net: "tcp", + Handler: mux, + ReadTimeout: 5 * time.Second, } } @@ -64,9 +104,30 @@ r *resolver.Resolver, c *cache.Cache, b *blocklist.Blocklist) (*Server, error) { WriteTimeout: 5 * time.Second, } } + + if cfg.Admin.Listen != "" { + s.admin = NewAdmin(cfg.Admin.Listen, s) + } + return s, nil } func (s *Server) Run(ctx context.Context) error { + s.mu.Lock() + if s.udp != nil { + s.upUDP = true + } + if s.tcp != nil { + s.upTCP = true + } + if s.doh != nil { + s.upDoH = true + } + s.mu.Unlock() + + if s.admin != nil { + s.admin.Start() + } + errCh := make(chan error, 3) if s.udp != nil { go func() { @@ -110,6 +171,9 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) Close() error { + if s.admin != nil { + s.admin.Close() + } if s.udp != nil { s.udp.Shutdown(context.Background()) } @@ -121,3 +185,54 @@ func (s *Server) Close() error { } return nil } + +type rateLimiter struct { + mu sync.Mutex + rate float64 + burst float64 + buckets map[string]*rateBucket +} + +type rateBucket struct { + tokens float64 + last time.Time +} + +func newRateLimiter(qps, burst int) *rateLimiter { + if qps <= 0 && burst <= 0 { + return nil + } + if burst <= 0 { + burst = qps + } + return &rateLimiter{ + rate: float64(qps), + burst: float64(burst), + buckets: make(map[string]*rateBucket), + } +} + +func (rl *rateLimiter) allow(ip string) bool { + now := time.Now() + rl.mu.Lock() + defer rl.mu.Unlock() + + b, ok := rl.buckets[ip] + if !ok { + b = &rateBucket{tokens: rl.burst, last: now} + rl.buckets[ip] = b + } + + elapsed := now.Sub(b.last).Seconds() + b.tokens += elapsed * rl.rate + if b.tokens > rl.burst { + b.tokens = rl.burst + } + b.last = now + + if b.tokens < 1 { + return false + } + b.tokens-- + return true +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 002aba8..2340acd 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,8 +1,8 @@ package server import ( - "log/slog" "context" + "log/slog" "testing" "time" -- cgit v1.2.3