summaryrefslogtreecommitdiff
path: root/internal/resolver/resolver.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/resolver/resolver.go')
-rw-r--r--internal/resolver/resolver.go91
1 files changed, 74 insertions, 17 deletions
diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go
index 9ee9cd6..5aa7bc1 100644
--- a/internal/resolver/resolver.go
+++ b/internal/resolver/resolver.go
@@ -7,7 +7,7 @@ import (
"net"
"time"
- "github.com/miekg/dns"
+ "codeberg.org/miekg/dns"
)
var (
@@ -26,16 +26,15 @@ type Resolver struct {
type Option func(*Resolver)
func New(opts ...Option) *Resolver {
+ transport := dns.NewTransport()
+ transport.ReadTimeout = 2 * time.Second
+
r := &Resolver{
roots: loadRootServers(),
maxDelegations: 30,
timeout: 2 * time.Second,
retries: 2,
- client: &dns.Client{
- Net: "udp",
- UDPSize: 4096,
- Timeout: 2 * time.Second,
- },
+ client: &dns.Client{Transport: transport},
}
for _, opt := range opts {
opt(r)
@@ -52,6 +51,7 @@ func WithRootAddresses(addrs []string) Option {
func WithTimeout(d time.Duration) Option {
return func(r *Resolver) {
r.timeout = d
+ r.client.Transport.ReadTimeout = d
}
}
@@ -81,6 +81,9 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
}
switch {
case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) > 0:
+ if needsCNAMEResolution(reply, qtype) {
+ return r.resolveCNAME(ctx, reply, qtype)
+ }
return reply, nil
case reply.Rcode == dns.RcodeNameError:
return reply, nil
@@ -104,6 +107,62 @@ func (r *Resolver) resolve(ctx context.Context, qname string, qtype uint16) (*dn
return nil, ErrMaxDelegations
}
+func needsCNAMEResolution(reply *dns.Msg, qtype uint16) bool {
+ hasCNAME := false
+ hasTarget := false
+ for _, rr := range reply.Answer {
+ if _, ok := rr.(*dns.CNAME); ok {
+ hasCNAME = true
+ }
+ if dns.RRToType(rr) == qtype && dns.RRToType(rr) != dns.TypeCNAME {
+ hasTarget = true
+ }
+ }
+ return hasCNAME && !hasTarget
+}
+func (r *Resolver) resolveCNAME(ctx context.Context, reply *dns.Msg, qtype uint16) (*dns.Msg, error) {
+ var target string
+ for _, rr := range reply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok {
+ target = cn.Target
+ }
+ }
+ if target == "" {
+ return reply, nil
+ }
+
+ const maxChain = 10
+ seen := map[string]bool{reply.Question[0].Header().Name: true}
+ for i := 0; i < maxChain; i++ {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+ if seen[target] {
+ break
+ }
+ seen[target] = true
+
+ chainReply, err := r.resolve(ctx, target, qtype)
+ if err != nil {
+ return reply, nil
+ }
+ reply.Answer = append(reply.Answer, chainReply.Answer...)
+
+ nextTarget := ""
+ for _, rr := range chainReply.Answer {
+ if cn, ok := rr.(*dns.CNAME); ok && cn.Header().Name == target {
+ nextTarget = cn.Target
+ }
+ }
+ if nextTarget == "" {
+ break
+ }
+ target = nextTarget
+ }
+ return reply, nil
+}
func isReferral(msg *dns.Msg) bool {
return !msg.Authoritative && len(msg.Ns) > 0
}
@@ -122,12 +181,12 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
for _, rr := range msg.Extra {
switch v := rr.(type) {
case *dns.A:
- if _, exists := glue[v.Hdr.Name]; !exists {
- glue[v.Hdr.Name] = v.A.String()
+ if _, exists := glue[v.Header().Name]; !exists {
+ glue[v.Header().Name] = v.A.Addr.String()
}
case *dns.AAAA:
- if _, exists := glue[v.Hdr.Name]; !exists {
- glue[v.Hdr.Name] = v.AAAA.String()
+ if _, exists := glue[v.Header().Name]; !exists {
+ glue[v.Header().Name] = v.AAAA.Addr.String()
}
}
}
@@ -156,7 +215,7 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
}
for _, rr := range reply.Answer {
if a, ok := rr.(*dns.A); ok {
- addrs = append(addrs, a.A.String())
+ addrs = append(addrs, a.A.Addr.String())
break
}
}
@@ -171,9 +230,8 @@ func (r *Resolver) nextServers(ctx context.Context, msg *dns.Msg) ([]string, err
func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
qname string, qtype uint16) (*dns.Msg, error) {
- msg := new(dns.Msg)
- msg.SetQuestion(qname, qtype)
- msg.SetEdns0(4096, false)
+ msg := dns.NewMsg(qname, qtype)
+ msg.UDPSize = 4096
msg.RecursionDesired = false
type result struct {
@@ -189,11 +247,10 @@ func (r *Resolver) exchangeWithRetries(ctx context.Context, servers []string,
}
go func(addr string) {
for attempt := 0; attempt < r.retries; attempt++ {
- reply, _, err := r.client.ExchangeContext(ctx, msg, addr)
+ reply, _, err := r.client.Exchange(ctx, msg, "udp", addr)
if err == nil {
if reply.Truncated {
- tcpClient := &dns.Client{Net: "tcp", Timeout: r.timeout}
- reply, _, err = tcpClient.ExchangeContext(ctx, msg, addr)
+ reply, _, err = r.client.Exchange(ctx, msg, "tcp", addr)
if err == nil {
ch <- result{reply: reply, err: nil}
return