summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-18 18:17:19 +0700
committerradhitya <alif@radhitya.org>2026-06-18 18:17:19 +0700
commit359f6a1cba3f2e281cefa727db34e3497dc15a2c (patch)
treeefaeeef69c903127b28c51f962de1714bd121ecd /internal
parentf5753c6a8cac5a57a042b0388f38abeff5d1f37d (diff)
add custom forwarders
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go19
-rw-r--r--internal/resolver/resolver.go90
2 files changed, 74 insertions, 35 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index f2624c2..1fa8069 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -27,8 +27,10 @@ type CacheConfig struct {
}
type ResolverConfig struct {
+ Mode string `toml:"mode"`
Timeout string `toml:"timeout"`
MaxDelegations int `toml:"max_delegations"`
+ Forwarders []string `toml:"forwarders"`
}
type BlocklistConfig struct {
@@ -71,6 +73,7 @@ func Default() Config{
MaxEntries: 100000,
},
Resolver: ResolverConfig{
+ Mode: "recursive",
Timeout: "2s",
MaxDelegations: 30,
},
@@ -111,6 +114,12 @@ func Merge(dst, src Config) Config {
if src.Resolver.MaxDelegations > 0 {
dst.Resolver.MaxDelegations = src.Resolver.MaxDelegations
}
+ if src.Resolver.Mode != "" {
+ dst.Resolver.Mode = src.Resolver.Mode
+ }
+ if len(src.Resolver.Forwarders) > 0 {
+ dst.Resolver.Forwarders = src.Resolver.Forwarders
+ }
if src.Blocklist.Response != "" {
dst.Blocklist.Response = src.Blocklist.Response
}
@@ -148,5 +157,15 @@ func (c Config) Validate() error {
default:
return fmt.Errorf("invalid blocklist response %q (want zero_ip or nxdomain)", c.Blocklist.Response)
}
+
+ switch c.Resolver.Mode {
+ case "recursive", "forward", "":
+ // nothing happened lol
+ default:
+ return fmt.Errorf("invalid resolver mode %q (recursive or forward)", c.Resolver.Mode)
+ }
+ if c.Resolver.Mode == "forward" && len(c.Resolver.Forwarders) == 0 {
+ return fmt.Errorf("resolver mode=forward requires at least one forwarder")
+ }
return nil
}
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 5aa7bc1..1e4f1da 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -20,6 +20,7 @@ type Resolver struct {
maxDelegations int
timeout time.Duration
retries int
+ forwarders []string
client *dns.Client
}
@@ -42,6 +43,11 @@ func New(opts ...Option) *Resolver {
return r
}
+func WithForwarders(addrs []string) Option {
+ return func(r *Resolver) {
+ r.forwarders = addrs
+ }
+}
func WithRootAddresses(addrs []string) Option {
return func(r *Resolver) {
r.roots = addrs
@@ -59,9 +65,22 @@ func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns
if ctx == nil {
ctx = context.Background()
}
+ if len(r.forwarders) > 0 {
+ return r.forward(ctx, qname, qtype)
+ }
return r.resolve(ctx, qname, qtype)
}
+func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
+ servers := make([]string, len(r.forwarders))
+ copy(servers, r.forwarders)
+ reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true)
+ if err != nil {
+ return nil, fmt.Errorf("forward %s %s: %w",
+ qname, dns.TypeToString[qtype],err)
+ }
+ return reply, nil
+}
func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
servers := make([]string, len(r.roots))
copy(servers, r.roots)
@@ -74,10 +93,11 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
if len(servers) == 0 {
return nil, ErrNoServers
}
- reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype)
+ reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype,
+ false)
if err != nil {
return nil, fmt.Errorf("resolve %s %s: %w",
- qname, dns.TypeToString[qtype], err)
+ qname, dns.TypeToString[qtype], err)
}
switch {
case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
@@ -128,40 +148,40 @@ func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint1
}
}
if target == "" {
- return reply, nil
- }
+ return reply, nil
+ }
- const maxChain = 10
- seen := map[string]bool{reply.Question[0].Header().Name: true}
- for i := 0; i < maxChain; i++ {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- }
- if seen[target] {
- break
- }
- seen[target] = true
+ const maxChain = 10
+ seen := map[string]bool{reply.Question[0].Header().Name: true}
+ for i := 0; i < maxChain; i++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+ if seen[target] {
+ break
+ }
+ seen[target] = true
- chainReply, err := r.resolve(ctx, target, qtype)
- if err != nil {
- return reply, nil
- }
- reply.Answer = append(reply.Answer, chainReply.Answer...)
+ chainReply, err := r.resolve(ctx, target, qtype)
+ if err != nil {
+ return reply, nil
+ }
+ reply.Answer = append(reply.Answer, chainReply.Answer...)
- nextTarget := ""
- for _, rr := range chainReply.Answer {
- if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
- nextTarget = cn.Target
- }
- }
- if nextTarget == "" {
- break
- }
- target = nextTarget
- }
- return reply, nil
+ nextTarget := ""
+ for _, rr := range chainReply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
+ nextTarget = cn.Target
+ }
+ }
+ if nextTarget == "" {
+ break
+ }
+ target = nextTarget
+ }
+ return reply, nil
}
func isReferral(msg *dns.Msg) bool {
return !msg.Authoritative && len(msg.Ns) > 0
@@ -228,11 +248,11 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
}
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
- qname string, qtype uint16) (*dns.Msg, error) {
+qname string, qtype uint16, rd bool) (*dns.Msg, error) {
msg := dns.NewMsg(qname, qtype)
msg.UDPSize = 4096
- msg.RecursionDesired = false
+ msg.RecursionDesired = rd
type result struct {
reply *dns.Msg