1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use core::fmt;
use std::{
    collections::{btree_map, BTreeMap},
    str::FromStr,
    sync::Arc,
};

use anyhow::{anyhow, Result};
use hickory_proto::{
    op::Message,
    rr::{
        domain::{IntoLabel, Label},
        Name, Record, RecordSet, RecordType, RrKey,
    },
    serialize::binary::BinDecodable,
};
use pkarr::SignedPacket;

/// A public key held as its raw 32 bytes.
///
/// The `derive_more` derives provide the conversions from and into the inner
/// `[u8; 32]`. The type is `Copy` and can be used as a map key (`Ord`, `Hash`).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(derive_more::From, derive_more::Into)]
pub struct PublicKeyBytes([u8; 32]);

impl PublicKeyBytes {
    pub fn from_z32(s: &str) -> Result<Self> {
        let bytes = z32::decode(s.as_bytes())?;
        let bytes: [u8; 32] = bytes.try_into().map_err(|_| anyhow!("invalid length"))?;
        Ok(Self(bytes))
    }

    pub fn to_z32(self) -> String {
        z32::encode(&self.0)
    }

    pub fn to_bytes(self) -> [u8; 32] {
        self.0
    }

    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    pub fn from_signed_packet(packet: &SignedPacket) -> Self {
        Self(packet.public_key().to_bytes())
    }
}

/// Displays the key as its z-base-32 encoding.
impl fmt::Display for PublicKeyBytes {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_z32();
        f.write_str(&encoded)
    }
}

/// Debug output is the type name wrapping the z-base-32 encoding, e.g.
/// `PublicKeyBytes(<z32>)`, rather than the raw byte array.
impl fmt::Debug for PublicKeyBytes {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_z32();
        f.write_str("PublicKeyBytes(")?;
        f.write_str(&encoded)?;
        f.write_str(")")
    }
}

impl From<pkarr::PublicKey> for PublicKeyBytes {
    fn from(value: pkarr::PublicKey) -> Self {
        Self(value.to_bytes())
    }
}

impl TryFrom<PublicKeyBytes> for pkarr::PublicKey {
    type Error = anyhow::Error;
    fn try_from(value: PublicKeyBytes) -> Result<Self, Self::Error> {
        pkarr::PublicKey::try_from(&value.0).map_err(anyhow::Error::from)
    }
}

impl FromStr for PublicKeyBytes {
    type Err = anyhow::Error;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::from_z32(s)
    }
}

/// Cheap borrow of the underlying 32 bytes.
impl AsRef<[u8; 32]> for PublicKeyBytes {
    fn as_ref(&self) -> &[u8; 32] {
        self.as_bytes()
    }
}

pub fn signed_packet_to_hickory_message(signed_packet: &SignedPacket) -> Result<Message> {
    let encoded = signed_packet.encoded_packet();
    let message = Message::from_bytes(&encoded)?;
    Ok(message)
}

/// Groups the answer records of a signed packet into record sets, with the
/// packet's zone label removed from every record name.
///
/// The zone label is the z-base-32 encoding of the packet's public key. It is
/// returned as the first tuple element; the second element maps each
/// `(name without zone, record type)` key to its record set.
///
/// Records are silently skipped when they
/// - are of type `SOA` or `NS`,
/// - have no labels,
/// - do not end in the zone label, or
/// - are rejected by `filter`.
///
/// `filter` sees each record with its original name, before the zone label is
/// stripped.
///
/// # Errors
///
/// Fails if the zone label cannot be built from the public key, if the packet
/// cannot be decoded as a DNS message, or if a record's labels cannot be
/// converted into a [`Label`] / [`Name`].
pub fn signed_packet_to_hickory_records_without_origin(
    signed_packet: &SignedPacket,
    filter: impl Fn(&Record) -> bool,
) -> Result<(Label, BTreeMap<RrKey, Arc<RecordSet>>)> {
    let common_zone = Label::from_utf8(&signed_packet.public_key().to_z32())?;
    let mut message = signed_packet_to_hickory_message(signed_packet)?;
    let answers = message.take_answers();
    let mut output: BTreeMap<RrKey, Arc<RecordSet>> = BTreeMap::new();
    for mut record in answers {
        // disallow SOA and NS records
        if matches!(record.record_type(), RecordType::SOA | RecordType::NS) {
            continue;
        }
        // Expect the z32 encoded pubkey as root name: pop it off the back of
        // the label iterator. This deliberately avoids `Name::num_labels()`,
        // which does not count a leading `*` label and would therefore make
        // wildcard names lose the wrong label when stripping the zone.
        let mut labels = record.name().iter();
        let zone = match labels.next_back() {
            Some(raw) => raw.into_label()?,
            // a name without labels cannot be inside the zone
            None => continue,
        };
        if zone != common_zone {
            continue;
        }
        if !filter(&record) {
            continue;
        }

        // `labels` now yields every label except the trailing zone label.
        let name_without_zone = Name::from_labels(labels)?;
        record.set_name(name_without_zone);

        let rrkey = RrKey::new(record.name().into(), record.record_type());
        match output.entry(rrkey) {
            btree_map::Entry::Vacant(e) => {
                let set: RecordSet = record.into();
                e.insert(Arc::new(set));
            }
            btree_map::Entry::Occupied(mut e) => {
                let set = e.get_mut();
                let serial = set.serial();
                // The Arc was created in this loop and has not been cloned
                // or handed out yet, so unique access is guaranteed.
                Arc::get_mut(set)
                    .expect("record set Arc is not shared while the map is being built")
                    .insert(record, serial);
            }
        }
    }
    Ok((common_zone, output))
}

/// Builds a copy of `input` whose name, and the name of every contained
/// record, has `origin` appended.
///
/// The new record set is created with `serial`, and each record is inserted
/// with that same serial. Only the records returned by
/// `records_without_rrsigs()` are copied over.
///
/// # Errors
///
/// Fails if `origin` cannot be appended to the record set's name.
pub fn record_set_append_origin(
    input: &RecordSet,
    origin: &Name,
    serial: u32,
) -> Result<RecordSet> {
    let full_name = input.name().clone().append_name(origin)?;
    let mut renamed = RecordSet::new(&full_name, input.record_type(), serial);
    // TODO: less clones
    input.records_without_rrsigs().cloned().for_each(|mut rec| {
        rec.set_name(full_name.clone());
        renamed.insert(rec, serial);
    });
    Ok(renamed)
}