package server

import (
	"bytes"
	"encoding/base64"
	"errors"
	"io"
	"log/slog"
	"net/http"

	"codeberg.org/miekg/dns"

	"linum/internal/cache"
)

// decodeDNSParam decodes the "dns" query parameter of a DoH GET request.
// Unpadded base64url (the form RFC 8484 specifies) is tried first; padded
// base64url is accepted as a fallback for lenient clients.
func decodeDNSParam(s string) ([]byte, error) {
	if b, err := base64.RawURLEncoding.DecodeString(s); err == nil {
		return b, nil
	}
	return base64.URLEncoding.DecodeString(s)
}

// dohHandler serves DNS-over-HTTPS queries.
//
// The client address is checked against the ACL and the rate limiter. The
// wire-format query is then taken from a POST body (Content-Type must be
// exactly "application/dns-message") or from the base64url "dns" parameter
// of a GET request. Only the first question is considered. The answer comes
// from the blocklist, the response cache, or s.buildResponse, in that order.
func (s *Server) dohHandler(w http.ResponseWriter, r *http.Request) {
	clientIP := parseHTTPClientIP(r.RemoteAddr)
	if !s.isAllowed(clientIP) {
		slog.Warn("doh query denied by ACL", "client", clientIP)
		http.Error(w, "refused", http.StatusForbidden)
		return
	}
	if !s.rateLimit(clientIP.String()) {
		slog.Warn("doh query rate limited", "client", clientIP)
		http.Error(w, "too many requests", http.StatusTooManyRequests)
		return
	}

	var raw []byte
	switch r.Method {
	case http.MethodPost:
		ct := r.Header.Get("Content-Type")
		if ct != "application/dns-message" {
			http.Error(w, "unsupported content type", http.StatusUnsupportedMediaType)
			return
		}
		// 65535 is the largest message a 16-bit DNS-over-TCP length prefix
		// can describe; anything bigger cannot be a valid DNS query.
		body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 65535))
		if err != nil {
			var tooBig *http.MaxBytesError
			if errors.As(err, &tooBig) {
				http.Error(w, "request too large", http.StatusRequestEntityTooLarge)
				return
			}
			http.Error(w, "read body", http.StatusBadRequest)
			return
		}
		raw = body
	case http.MethodGet:
		param := r.URL.Query().Get("dns")
		if param == "" {
			http.Error(w, "missing dns param", http.StatusBadRequest)
			return
		}
		decoded, err := decodeDNSParam(param)
		if err != nil {
			http.Error(w, "invalid base64url", http.StatusBadRequest)
			return
		}
		raw = decoded
	default:
		// RFC 9110 requires a 405 response to list the supported methods.
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	msg := new(dns.Msg)
	msg.Data = raw
	if err := msg.Unpack(); err != nil {
		http.Error(w, "invalid dns message", http.StatusBadRequest)
		return
	}
	if len(msg.Question) == 0 {
		http.Error(w, "no question", http.StatusBadRequest)
		return
	}

	q := msg.Question[0]
	qname := q.Header().Name
	qtype := dns.RRToType(q)
	qclass := q.Header().Class

	if s.blocklist != nil && s.blocklist.IsBlocked(qname) {
		httpWriteMsg(w, s.blockedResponse(msg))
		slog.Debug("doh query served",
			"qname", qname,
			"qtype", dns.TypeToString[qtype],
			"blocked", true,
		)
		return
	}

	if s.cache != nil {
		key := cache.Key{Name: qname, Qtype: qtype, Class: qclass}
		// A cached entry too short to hold a DNS ID is treated as a miss
		// instead of panicking on the index below.
		if packed, ok := s.cache.Get(key); ok && len(packed) >= 2 {
			// Patch the query ID into a private copy. NOTE(review): confirm
			// whether cache.Get returns a copy. If it returns its stored slice,
			// writing in place would race with concurrent requests for the same
			// key and could hand one client another client's ID.
			out := bytes.Clone(packed)
			out[0] = byte(msg.ID >> 8)
			out[1] = byte(msg.ID)
			httpWritePacked(w, out)
			slog.Debug("doh query served",
				"qname", qname,
				"qtype", dns.TypeToString[qtype],
				"cached", true,
			)
			return
		}
	}

	// The second result of buildResponse is discarded, as before.
	// Guard against a nil response so the handler cannot panic on resp.Rcode
	// or inside msg.Pack.
	resp, _ := s.buildResponse(msg)
	if resp == nil {
		http.Error(w, "no response", http.StatusInternalServerError)
		return
	}
	httpWriteMsg(w, resp)
	slog.Debug("doh query served",
		"qname", qname,
		"qtype", dns.TypeToString[qtype],
		"rcode", dns.RcodeToString[resp.Rcode],
	)
}

// httpWriteMsg packs msg and writes it as an application/dns-message
// response. A pack failure produces a 500 before anything has been written.
func httpWriteMsg(w http.ResponseWriter, msg *dns.Msg) {
	if err := msg.Pack(); err != nil {
		http.Error(w, "pack response", http.StatusInternalServerError)
		return
	}
	httpWritePacked(w, msg.Data)
}

// httpWritePacked writes already-packed DNS wire data with the headers shared
// by every successful DoH answer. A write error is only logged: once Write is
// called, the 200 status line has been committed and cannot be changed.
func httpWritePacked(w http.ResponseWriter, data []byte) {
	w.Header().Set("Content-Type", "application/dns-message")
	w.Header().Set("Cache-Control", "no-cache, max-age=0")
	if _, err := w.Write(data); err != nil {
		slog.Error("doh write failed", "err", err)
	}
}