1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
//! Implementation of a DNS name server for iroh node announces

use std::{
    collections::BTreeMap,
    io,
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    sync::Arc,
    time::Duration,
};

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use bytes::Bytes;
use hickory_server::{
    authority::{Catalog, MessageResponse, ZoneType},
    proto::{
        rr::{
            rdata::{self},
            RData, Record, RecordSet, RecordType, RrKey,
        },
        serialize::{binary::BinEncoder, txt::RDataParser},
        {self},
    },
    resolver::Name,
    server::{Request, RequestHandler, ResponseHandler, ResponseInfo},
    store::in_memory::InMemoryAuthority,
};
use iroh_metrics::inc;
use proto::{op::ResponseCode, rr::LowerName};
use serde::{Deserialize, Serialize};
use tokio::{
    net::{TcpListener, UdpSocket},
    sync::broadcast,
};
use tracing::{debug, info};

use self::node_authority::NodeAuthority;
use crate::{metrics::Metrics, store::ZoneStore};

mod node_authority;

/// TTL (in seconds) of the static `NS` records: 12 hours.
const DEFAULT_NS_TTL: u32 = 12 * 60 * 60;
/// TTL (in seconds) of the static `SOA` records: 14 days.
const DEFAULT_SOA_TTL: u32 = 14 * 24 * 60 * 60;
/// TTL (in seconds) of the static `A` and `AAAA` records: 1 hour.
const DEFAULT_A_TTL: u32 = 60 * 60;

/// DNS server settings.
///
/// Consumed by [`DnsServer::spawn`] (socket binding) and [`DnsHandler::new`] (zone data).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DnsConfig {
    /// The port on which the DNS server listens; both a UDP socket and a TCP listener are bound.
    pub port: u16,
    /// The IPv4 or IPv6 address to bind the UDP socket and TCP listener to.
    /// Uses `0.0.0.0` if unspecified.
    pub bind_addr: Option<IpAddr>,
    /// SOA record data for any authoritative DNS records, in textual form.
    ///
    /// The string is split on ASCII whitespace and parsed as SOA record data; its serial is
    /// used as the serial of the static record sets.
    pub default_soa: String,
    /// Default time to live for returned DNS records (TXT & SOA)
    // NOTE(review): the static SOA/A/AAAA/NS records built in this module use fixed TTL
    // constants and do not read this value; presumably it is used by the node authority — confirm.
    pub default_ttl: u32,
    /// Domains used for serving the `_iroh_node.<nodeid>.<origin>` DNS TXT entry.
    ///
    /// Each entry must parse as a domain name, otherwise handler construction fails.
    pub origins: Vec<String>,

    /// `A` record to set for all origins (served with a 1 hour TTL)
    pub rr_a: Option<Ipv4Addr>,
    /// `AAAA` record to set for all origins (served with a 1 hour TTL)
    pub rr_aaaa: Option<Ipv6Addr>,
    /// `NS` record to set for all origins (served with a 12 hour TTL).
    ///
    /// Relative names are resolved against the root zone.
    pub rr_ns: Option<String>,
}

/// A DNS server that serves pkarr signed packets.
///
/// Holds the hickory server future that drives a [`DnsHandler`], along with the address
/// the server's UDP socket was bound to.
pub struct DnsServer {
    server: hickory_server::ServerFuture<DnsHandler>,
    local_addr: SocketAddr,
}

impl DnsServer {
    /// Spawn the server.
    ///
    /// Binds a UDP socket to `config.bind_addr` (or `0.0.0.0`) and `config.port`. It then binds
    /// a TCP listener to the address the UDP socket actually received, and registers both with
    /// the hickory server. When `config.port` is `0`, the OS picks the UDP port and TCP is bound
    /// to that same port, so [`DnsServer::local_addr`] is correct for both transports.
    ///
    /// # Errors
    ///
    /// Returns an error if the UDP socket or the TCP listener cannot be bound.
    pub async fn spawn(config: DnsConfig, dns_handler: DnsHandler) -> Result<Self> {
        const TCP_TIMEOUT: Duration = Duration::from_millis(1000);
        let mut server = hickory_server::ServerFuture::new(dns_handler);

        let bind_addr = SocketAddr::new(
            config.bind_addr.unwrap_or(Ipv4Addr::UNSPECIFIED.into()),
            config.port,
        );

        let socket = UdpSocket::bind(bind_addr).await?;
        // Resolve the concrete address (port 0 means "any free port") and bind TCP to the same
        // port; binding TCP to `bind_addr` would give it a different ephemeral port.
        let local_addr = socket.local_addr()?;
        let listener = TcpListener::bind(local_addr).await?;

        server.register_socket(socket);
        server.register_listener(listener, TCP_TIMEOUT);
        info!("DNS server listening on {}", local_addr);

        Ok(Self { server, local_addr })
    }

    /// Get the local address of the UDP/TCP socket.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Shutdown the server and wait for all tasks to complete.
    pub async fn shutdown(mut self) -> Result<()> {
        self.server.shutdown_gracefully().await?;
        Ok(())
    }

    /// Wait for all tasks to complete.
    ///
    /// Runs forever unless tasks fail.
    pub async fn run_until_done(mut self) -> Result<()> {
        self.server.block_until_done().await?;
        Ok(())
    }
}

/// State for serving DNS
#[derive(Clone, derive_more::Debug)]
pub struct DnsHandler {
    #[debug("Catalog")]
    catalog: Arc<Catalog>,
}

impl DnsHandler {
    /// Create a DNS server given some settings, a connection to the DB for DID-by-username lookups
    /// and the server DID to serve under `_did.<origin>`.
    pub fn new(zone_store: ZoneStore, config: &DnsConfig) -> Result<Self> {
        let origins = config
            .origins
            .iter()
            .map(Name::from_utf8)
            .collect::<Result<Vec<_>, _>>()?;

        let (static_authority, serial) = create_static_authority(&origins, config)?;
        let authority = NodeAuthority::new(zone_store, static_authority, origins, serial)?;
        let authority = Arc::new(authority);

        let mut catalog = Catalog::new();
        for origin in authority.origins() {
            catalog.upsert(LowerName::from(origin), Box::new(Arc::clone(&authority)));
        }

        Ok(Self {
            catalog: Arc::new(catalog),
        })
    }

    /// Handle a DNS request
    pub async fn answer_request(&self, request: Request) -> Result<Bytes> {
        let (tx, mut rx) = broadcast::channel(1);
        let response_handle = Handle(tx);
        self.handle_request(&request, response_handle).await;
        Ok(rx.recv().await?)
    }
}

#[async_trait::async_trait]
impl RequestHandler for DnsHandler {
    /// Route an incoming request through the catalog and record request/outcome metrics.
    ///
    /// Outcomes are counted as:
    /// - success: `NoError` with at least one answer;
    /// - not-found: `NoError` with no answers, or `NXDomain`;
    /// - error: any other response code.
    async fn handle_request<R: ResponseHandler>(
        &self,
        request: &Request,
        response_handle: R,
    ) -> ResponseInfo {
        use hickory_server::server::Protocol;

        inc!(Metrics, dns_requests);
        let protocol = request.protocol();
        if matches!(protocol, Protocol::Udp) {
            inc!(Metrics, dns_requests_udp);
        } else if matches!(protocol, Protocol::Https) {
            inc!(Metrics, dns_requests_https);
        }
        debug!(protocol = %protocol, query = %request.query(), "incoming DNS request");

        let info = self.catalog.handle_request(request, response_handle).await;
        match info.response_code() {
            ResponseCode::NoError if info.answer_count() > 0 => {
                inc!(Metrics, dns_lookup_success);
            }
            ResponseCode::NoError | ResponseCode::NXDomain => {
                inc!(Metrics, dns_lookup_notfound);
            }
            _ => {
                inc!(Metrics, dns_lookup_error);
            }
        }
        info
    }
}

/// Response sink for a single DNS request.
///
/// Implements `ResponseHandler` by encoding the response to wire format and publishing the
/// resulting bytes on the wrapped broadcast channel.
#[derive(Clone, Debug)]
pub struct Handle(pub broadcast::Sender<Bytes>);

#[async_trait]
impl ResponseHandler for Handle {
    /// Encode `response` into DNS wire format and send the bytes over the broadcast channel.
    ///
    /// # Errors
    ///
    /// Returns an I/O error in either of these cases:
    /// - the response cannot be encoded;
    /// - the channel has no remaining receiver to deliver the bytes to.
    async fn send_response<'a>(
        &mut self,
        response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a proto::rr::Record> + Send + 'a,
            impl Iterator<Item = &'a proto::rr::Record> + Send + 'a,
            impl Iterator<Item = &'a proto::rr::Record> + Send + 'a,
            impl Iterator<Item = &'a proto::rr::Record> + Send + 'a,
        >,
    ) -> io::Result<ResponseInfo> {
        let mut buf = Vec::with_capacity(512);
        let info = {
            let mut encoder = BinEncoder::new(&mut buf);
            response.destructive_emit(&mut encoder)?
        };

        // `send` only fails when every receiver has been dropped. Report that to the caller
        // instead of panicking inside the DNS server task.
        self.0.send(Bytes::from(buf)).map_err(|_| {
            io::Error::new(
                io::ErrorKind::BrokenPipe,
                "DNS response receiver was dropped",
            )
        })?;

        Ok(info)
    }
}

fn create_static_authority(
    origins: &[Name],
    config: &DnsConfig,
) -> Result<(InMemoryAuthority, u32)> {
    let soa = RData::parse(
        RecordType::SOA,
        config.default_soa.split_ascii_whitespace(),
        None,
    )?
    .into_soa()
    .map_err(|_| anyhow!("Couldn't parse SOA: {}", config.default_soa))?;
    let serial = soa.serial();
    let mut records = BTreeMap::new();
    for name in origins {
        push_record(
            &mut records,
            serial,
            Record::from_rdata(name.clone(), DEFAULT_SOA_TTL, RData::SOA(soa.clone())),
        );
        if let Some(addr) = config.rr_a {
            push_record(
                &mut records,
                serial,
                Record::from_rdata(name.clone(), DEFAULT_A_TTL, RData::A(addr.into())),
            );
        }
        if let Some(addr) = config.rr_aaaa {
            push_record(
                &mut records,
                serial,
                Record::from_rdata(name.clone(), DEFAULT_A_TTL, RData::AAAA(addr.into())),
            );
        }
        if let Some(ns) = &config.rr_ns {
            let ns = Name::parse(ns, Some(&Name::root()))?;
            push_record(
                &mut records,
                serial,
                Record::from_rdata(name.clone(), DEFAULT_NS_TTL, RData::NS(rdata::NS(ns))),
            );
        }
    }

    let static_authority = InMemoryAuthority::new(Name::root(), records, ZoneType::Primary, false)
        .map_err(|e| anyhow!(e))?;

    Ok((static_authority, serial))
}

/// Add `record` to the record set for its name and type in `records`.
///
/// If no such set exists yet, one is created and stamped with `serial`. Records that share a
/// name and type accumulate in one set instead of the newest replacing the earlier ones.
fn push_record(records: &mut BTreeMap<RrKey, RecordSet>, serial: u32, record: Record) {
    let key = RrKey::new(record.name().clone().into(), record.record_type());
    records
        .entry(key)
        .or_insert_with(|| RecordSet::new(record.name(), record.record_type(), serial))
        .insert(record, serial);
}