summaryrefslogtreecommitdiff
path: root/internal/resolver/resolver.go
diff options
context:
space:
mode:
authorradhitya <alif@radhitya.org>2026-06-18 18:17:19 +0700
committerradhitya <alif@radhitya.org>2026-06-18 18:17:19 +0700
commit359f6a1cba3f2e281cefa727db34e3497dc15a2c (patch)
treeefaeeef69c903127b28c51f962de1714bd121ecd /internal/resolver/resolver.go
parentf5753c6a8cac5a57a042b0388f38abeff5d1f37d (diff)
add custom forwarders
Diffstat (limited to 'internal/resolver/resolver.go')
-rw-r--r--internal/resolver/resolver.go90
1 files changed, 55 insertions, 35 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 5aa7bc1..1e4f1da 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -20,6 +20,7 @@ type Resolver struct {
maxDelegations int
timeout time.Duration
retries int
+ forwarders []string
client *dns.Client
}
@@ -42,6 +43,11 @@ func New(opts ...Option) *Resolver {
return r
}
+func WithForwarders(addrs []string) Option {
+ return func(r *Resolver) {
+ r.forwarders = addrs
+ }
+}
func WithRootAddresses(addrs []string) Option {
return func(r *Resolver) {
r.roots = addrs
@@ -59,9 +65,22 @@ func (r *Resolver) Lookup(ctx context.Context, qname string, qtype uint16) (*dns
if ctx == nil {
ctx = context.Background()
}
+ if len(r.forwarders) > 0 {
+ return r.forward(ctx, qname, qtype)
+ }
return r.resolve(ctx, qname, qtype)
}
+func (r *Resolver) forward(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
+ servers := make([]string, len(r.forwarders))
+ copy(servers, r.forwarders)
+ reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype, true)
+ if err != nil {
+ return nil, fmt.Errorf("forward %s %s: %w",
+ qname, dns.TypeToString[qtype],err)
+ }
+ return reply, nil
+}
func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dns.Msg, error) {
servers := make([]string, len(r.roots))
copy(servers, r.roots)
@@ -74,10 +93,11 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
if len(servers) == 0 {
return nil, ErrNoServers
}
- reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype)
+ reply, err := r.exchangeWithRetries(ctx, servers, qname, qtype,
+ false)
if err != nil {
return nil, fmt.Errorf("resolve %s %s: %w",
- qname, dns.TypeToString[qtype], err)
+ qname, dns.TypeToString[qtype], err)
}
switch {
case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
@@ -128,40 +148,40 @@ func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint1
}
}
if target == "" {
- return reply, nil
- }
+ return reply, nil
+ }
- const maxChain = 10
- seen := map[string]bool{reply.Question[0].Header().Name: true}
- for i := 0; i < maxChain; i++ {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- }
- if seen[target] {
- break
- }
- seen[target] = true
+ const maxChain = 10
+ seen := map[string]bool{reply.Question[0].Header().Name: true}
+ for i := 0; i < maxChain; i++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+ if seen[target] {
+ break
+ }
+ seen[target] = true
- chainReply, err := r.resolve(ctx, target, qtype)
- if err != nil {
- return reply, nil
- }
- reply.Answer = append(reply.Answer, chainReply.Answer...)
+ chainReply, err := r.resolve(ctx, target, qtype)
+ if err != nil {
+ return reply, nil
+ }
+ reply.Answer = append(reply.Answer, chainReply.Answer...)
- nextTarget := ""
- for _, rr := range chainReply.Answer {
- if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
- nextTarget = cn.Target
- }
- }
- if nextTarget == "" {
- break
- }
- target = nextTarget
- }
- return reply, nil
+ nextTarget := ""
+ for _, rr := range chainReply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
+ nextTarget = cn.Target
+ }
+ }
+ if nextTarget == "" {
+ break
+ }
+ target = nextTarget
+ }
+ return reply, nil
}
func isReferral(msg *dns.Msg) bool {
return !msg.Authoritative && len(msg.Ns) > 0
@@ -228,11 +248,11 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
}
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
- qname string, qtype uint16) (*dns.Msg, error) {
+qname string, qtype uint16, rd bool) (*dns.Msg, error) {
msg := dns.NewMsg(qname, qtype)
msg.UDPSize = 4096
- msg.RecursionDesired = false
+ msg.RecursionDesired = rd
type result struct {
reply *dns.Msg