From d173554892339e5211020c60d6af610840eef7ed Mon Sep 17 00:00:00 2001 From: radhitya Date: Sun, 14 Jun 2026 17:17:56 +0700 Subject: config, rebranding, fix cache --- internal/cache/cache.go | 9 ++- internal/config/config.go | 152 +++++++++++++++++++++++++++++++++++++++++ internal/server/handler.go | 6 +- internal/server/server.go | 6 +- internal/server/server_test.go | 2 +- 5 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 internal/config/config.go (limited to 'internal') diff --git a/internal/cache/cache.go b/internal/cache/cache.go index a2d86a0..d6a31f3 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -2,6 +2,7 @@ package cache import ( "database/sql" + "log/slog" "sync" "sync/atomic" "time" @@ -199,16 +200,22 @@ func (c *Cache) evictLoop() { } } func (c *Cache) writeToDB(key Key, e *entry) { + if c.db == nil { + return + } data, err := e.msg.Pack() if err != nil { return } - c.db.Exec( + _, err = c.db.Exec( `INSERT OR REPLACE INTO cache (name, qtype, class, data, stored_at, ttl_ns) VALUES (?, ?, ?, ?, ?, ?)`, key.Name, key.Qtype, key.Class, data, e.storedAt.UnixNano(), int64(e.ttl), ) + if err != nil { + slog.Warn("cache write to db failed", "err", err) + } } func (c *Cache) loadFromDB() { diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..f2624c2 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,152 @@ +package config + +import ( + "flag" + "fmt" + + "github.com/BurntSushi/toml" +) + +type Config struct { + Server ServerConfig `toml:"server"` + Cache CacheConfig `toml:"cache"` + Resolver ResolverConfig `toml:"resolver"` + Blocklist BlocklistConfig `toml:"blocklist"` + Log LogConfig `toml:"log"` +} + +type ServerConfig struct { + ListenUDP string `toml:"listen_udp"` + ListenTCP string `toml:"listen_tcp"` + ListenDOH string `toml:"listen_doh"` +} + +type CacheConfig struct { + MaxEntries int `toml:"max_entries"` + DBPath string `toml:"db_path"` +} + +type ResolverConfig struct { + Timeout string `toml:"timeout"` + MaxDelegations int `toml:"max_delegations"` +} + +type BlocklistConfig struct { + Response string `toml:"response"` + Files []string `toml:"files"` + URLs []string `toml:"urls"` +} + +type LogConfig struct { + Level string `toml:"level"` +} + +type CLIFlags struct { + Config string + LogLevel string + ListenUDP string + ListenTCP string + ListenDOH string +} + +func ParseFlags() CLIFlags { + var f CLIFlags + flag.StringVar(&f.Config, "config", "linum.toml", "path to config file") + flag.StringVar(&f.LogLevel, "loglevel", "", "log level (debug|info|warn|error)") + flag.StringVar(&f.ListenUDP, "udp", "", "UDP listen address") + flag.StringVar(&f.ListenTCP, "tcp", "", "TCP listen address") + flag.StringVar(&f.ListenDOH, "doh", "", "DoH listen address") + flag.Parse() + return f +} + +func Default() Config{ + return Config{ + Server: ServerConfig{ + ListenUDP: ":5353", + ListenTCP: ":5353", + ListenDOH: ":8443", + }, + Cache: CacheConfig{ + MaxEntries: 100000, + }, + Resolver: ResolverConfig{ + Timeout: "2s", + MaxDelegations: 30, + }, + Blocklist: BlocklistConfig{ + Response: "zero_ip", + }, + Log: LogConfig{ + Level: "info", + }, + } +} + +func LoadFile(path string) (Config, error) { + var cfg Config + _, err := toml.DecodeFile(path, &cfg) + return cfg, err +} + +func Merge(dst, src Config) Config { + if src.Server.ListenUDP != "" { + dst.Server.ListenUDP = src.Server.ListenUDP + } + if src.Server.ListenTCP != "" { + dst.Server.ListenTCP = src.Server.ListenTCP + } + if src.Server.ListenDOH != "" { + dst.Server.ListenDOH = src.Server.ListenDOH + } + if src.Cache.MaxEntries > 0 { + dst.Cache.MaxEntries = src.Cache.MaxEntries + } + if src.Cache.DBPath != "" { + dst.Cache.DBPath = src.Cache.DBPath + } + if src.Resolver.Timeout != "" { + dst.Resolver.Timeout = src.Resolver.Timeout + } + if src.Resolver.MaxDelegations > 0 { + dst.Resolver.MaxDelegations = src.Resolver.MaxDelegations + } + if src.Blocklist.Response != "" { + dst.Blocklist.Response = src.Blocklist.Response + } + if len(src.Blocklist.Files) > 0 { + dst.Blocklist.Files = src.Blocklist.Files + } + if len(src.Blocklist.URLs) > 0 { + dst.Blocklist.URLs = src.Blocklist.URLs + } + if src.Log.Level != "" { + dst.Log.Level = src.Log.Level + } + return dst +} + +func (f CLIFlags) Apply(cfg Config) Config { + if f.ListenUDP != "" { + cfg.Server.ListenUDP = f.ListenUDP + } + if f.ListenTCP != "" { + cfg.Server.ListenTCP = f.ListenTCP + } + if f.ListenDOH != "" { + cfg.Server.ListenDOH = f.ListenDOH + } + if f.LogLevel != "" { + cfg.Log.Level = f.LogLevel + } + return cfg +} + +func (c Config) Validate() error { + switch c.Blocklist.Response { + case "zero_ip", "nxdomain", "": + default: + return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response) + } + return nil +} diff --git a/internal/server/handler.go b/internal/server/handler.go index 4aa771f..406b7ed 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -7,8 +7,8 @@ import ( "time" "github.com/miekg/dns" - "sdns/internal/blocklist" - "sdns/internal/cache" + "linum/internal/blocklist" + "linum/internal/cache" ) func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { @@ -33,7 +33,6 @@ func (s *Server) handleQuery(w dns.ResponseWriter, req *dns.Msg) { "qname", req.Question[0].Name, "qtype", dns.TypeToString[req.Question[0].Qtype], "rcode", dns.RcodeToString[resp.Rcode], - "client", w.RemoteAddr().String(), "blocked", blocked, ) } @@ -47,6 +46,7 @@ func (s *Server) buildResponse(req *dns.Msg) (*dns.Msg, bool) { if s.cache != nil { key := cache.Key{Name: q.Name, Qtype: q.Qtype, Class: q.Qclass} if cached, ok := s.cache.Get(key); ok { + cached.Id = req.Id return cached, false } } diff --git a/internal/server/server.go b/internal/server/server.go index e0490bd..ec0dec9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,9 +7,9 @@ import ( "time" "github.com/miekg/dns" - "sdns/internal/resolver" - "sdns/internal/blocklist" - "sdns/internal/cache" + "linum/internal/resolver" + "linum/internal/blocklist" + "linum/internal/cache" ) type Server struct { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index c49d5f3..e59a131 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/miekg/dns" - "sdns/internal/resolver" + "linum/internal/resolver" ) func testServer(t *testing.T) *Server { -- cgit v1.2.3