mod encryption;
use std::{
    fmt::{Debug, Display},
    hash::Hash,
    str::FromStr,
    sync::Mutex,
    time::Duration,
};
pub use ed25519_dalek::Signature;
use ed25519_dalek::{SignatureError, SigningKey, VerifyingKey};
use once_cell::sync::OnceCell;
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use ssh_key::LineEnding;
use ttl_cache::TtlCache;
pub use self::encryption::SharedSecret;
use self::encryption::{public_ed_box, secret_ed_box};
/// Key material derived from a single ed25519 public key.
///
/// Instances live in the process-wide key cache so the derivation is not repeated every time
/// the same [`PublicKey`] is used.
#[derive(Debug)]
struct CryptoKeys {
    /// The decoded ed25519 verifying key.
    verifying_key: VerifyingKey,
    /// The `crypto_box` public key obtained from `verifying_key` via `public_ed_box`.
    crypto_box: crypto_box::PublicKey,
}

impl CryptoKeys {
    /// Bundles `verifying_key` with the `crypto_box` key derived from it.
    fn new(verifying_key: VerifyingKey) -> Self {
        Self {
            crypto_box: public_ed_box(&verifying_key),
            verifying_key,
        }
    }
}
/// Time-to-live passed to the cache for every [`CryptoKeys`] entry inserted into it.
const KEY_CACHE_TTL: Duration = Duration::from_secs(60);
/// Maximum number of entries the key cache is created with (16 Ki).
const KEY_CACHE_CAPACITY: usize = 16 * 1024;
/// Process-wide cache of derived key material, keyed by the raw 32 public-key bytes.
///
/// Created lazily on first access by `lock_key_cache`.
static KEY_CACHE: OnceCell<Mutex<TtlCache<[u8; 32], CryptoKeys>>> = OnceCell::new();
/// Locks the process-wide key cache, creating it with `KEY_CACHE_CAPACITY` on first use.
///
/// # Panics
///
/// Panics if the mutex is poisoned, i.e. another thread panicked while holding the lock.
fn lock_key_cache() -> std::sync::MutexGuard<'static, TtlCache<[u8; 32], CryptoKeys>> {
    KEY_CACHE
        .get_or_init(|| Mutex::new(TtlCache::new(KEY_CACHE_CAPACITY)))
        .lock()
        .expect("not poisoned")
}
/// Looks up the derived key material for `key` and returns `f` applied to it.
///
/// On a cache miss, the key is decoded with `VerifyingKey::from_bytes` and its [`CryptoKeys`]
/// are derived. The result is inserted with `KEY_CACHE_TTL` before `f` runs. The cache lock is
/// held for the whole call, including while `f` runs.
///
/// # Errors
///
/// Returns the [`SignatureError`] from `VerifyingKey::from_bytes` when `key` is not a valid
/// ed25519 public key. This is only checked on a cache miss.
fn get_or_create_crypto_keys<T>(
    key: &[u8; 32],
    f: impl Fn(&CryptoKeys) -> T,
) -> std::result::Result<T, SignatureError> {
    let mut cache = lock_key_cache();
    let output = match cache.entry(*key) {
        ttl_cache::Entry::Occupied(hit) => f(hit.get()),
        ttl_cache::Entry::Vacant(miss) => {
            let verifying_key = VerifyingKey::from_bytes(key)?;
            let derived = miss.insert(CryptoKeys::new(verifying_key), KEY_CACHE_TTL);
            f(derived)
        }
    };
    Ok(output)
}
/// An ed25519 public key, stored as its 32 raw bytes.
///
/// The public constructors in this module (`from_bytes`, `TryFrom`, `FromStr`, `Deserialize`,
/// `From<VerifyingKey>`) only produce values whose bytes form a valid key.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct PublicKey([u8; 32]);

/// Identifier of a node; an alias for its [`PublicKey`].
pub type NodeId = PublicKey;
/// Hashes exactly as the underlying 32-byte array does.
impl Hash for PublicKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(self.as_bytes(), state);
    }
}
impl Serialize for PublicKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.serialize_str(&self.to_string())
        } else {
            self.0.serialize(serializer)
        }
    }
}
impl<'de> Deserialize<'de> for PublicKey {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        if deserializer.is_human_readable() {
            let s = String::deserialize(deserializer)?;
            Self::from_str(&s).map_err(serde::de::Error::custom)
        } else {
            let data: [u8; 32] = serde::Deserialize::deserialize(deserializer)?;
            Self::try_from(data.as_ref()).map_err(serde::de::Error::custom)
        }
    }
}
impl PublicKey {
    /// Length in bytes of an encoded public key.
    pub const LENGTH: usize = ed25519_dalek::PUBLIC_KEY_LENGTH;

    /// Returns the raw 32 key bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Returns the ed25519 verifying key for this public key, going through the key cache.
    ///
    /// # Panics
    ///
    /// Panics if the stored bytes fail to decode. This does not happen for keys built by the
    /// validating constructors.
    fn public(&self) -> VerifyingKey {
        get_or_create_crypto_keys(&self.0, |keys| keys.verifying_key)
            .expect("key has been checked")
    }

    /// Returns the `crypto_box` public key derived from this key, going through the key cache.
    ///
    /// # Panics
    ///
    /// Same condition as `public`.
    fn public_crypto_box(&self) -> crypto_box::PublicKey {
        get_or_create_crypto_keys(&self.0, |keys| keys.crypto_box.clone())
            .expect("key has been checked")
    }

    /// Builds a public key from raw bytes, checking that they form a valid ed25519 key.
    ///
    /// On success, the derived key material is present in the key cache.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] if `bytes` is not a valid ed25519 public key.
    pub fn from_bytes(bytes: &[u8; 32]) -> Result<Self, SignatureError> {
        // Only the validation (and cache population) matters here; the lookup result is unused.
        get_or_create_crypto_keys(bytes, |_| ()).map(|()| Self(*bytes))
    }

    /// Verifies `signature` over `message` using ed25519 strict verification.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] if the signature does not verify.
    pub fn verify(&self, message: &[u8], signature: &Signature) -> Result<(), SignatureError> {
        let verifying_key = self.public();
        verifying_key.verify_strict(message, signature)
    }

    /// Returns a short form of the key: its first 5 bytes as 10 lowercase hex characters.
    pub fn fmt_short(&self) -> String {
        let prefix = &self.0[..5];
        data_encoding::HEXLOWER.encode(prefix)
    }
}
/// Parses a public key from a byte slice.
///
/// A 32-byte slice is validated through [`PublicKey::from_bytes`]. Any other slice is passed to
/// `VerifyingKey::try_from`; the resulting error (or key) comes from `ed25519_dalek`.
impl TryFrom<&[u8]> for PublicKey {
    type Error = SignatureError;

    #[inline]
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        if let Ok(array) = <&[u8; 32]>::try_from(bytes) {
            return Self::from_bytes(array);
        }
        VerifyingKey::try_from(bytes).map(Self::from)
    }
}
impl TryFrom<&[u8; 32]> for PublicKey {
    type Error = SignatureError;
    #[inline]
    fn try_from(bytes: &[u8; 32]) -> Result<Self, Self::Error> {
        Self::from_bytes(bytes)
    }
}
/// Exposes the raw key bytes as a slice.
impl AsRef<[u8]> for PublicKey {
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
impl From<VerifyingKey> for PublicKey {
    fn from(verifying_key: VerifyingKey) -> Self {
        let item = CryptoKeys::new(verifying_key);
        let key = *verifying_key.as_bytes();
        let mut table = lock_key_cache();
        table.insert(key, item, KEY_CACHE_TTL);
        PublicKey(key)
    }
}
/// Formats as `PublicKey(<64 lowercase hex characters>)`.
impl Debug for PublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = data_encoding::HEXLOWER.encode(&self.0);
        write!(f, "PublicKey({hex})")
    }
}
/// Formats the key as 64 lowercase hex characters; the `FromStr` impl parses this form back.
impl Display for PublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&data_encoding::HEXLOWER.encode(&self.0))
    }
}
#[derive(thiserror::Error, Debug)]
pub enum KeyParsingError {
    #[error("decoding: {0}")]
    Decode(#[from] data_encoding::DecodeError),
    #[error("key: {0}")]
    Key(#[from] ed25519_dalek::SignatureError),
    #[error("invalid length")]
    DecodeInvalidLength,
}
impl FromStr for PublicKey {
    type Err = KeyParsingError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let bytes = decode_base32_hex(s)?;
        Ok(Self::try_from(&bytes)?)
    }
}
/// An ed25519 secret (signing) key.
///
/// The matching `crypto_box` secret key is derived lazily on first use and then kept
/// alongside the signing key.
#[derive(Clone)]
pub struct SecretKey {
    /// The ed25519 signing key; the source of all other material held here.
    secret: SigningKey,
    /// Lazily derived `crypto_box` counterpart of `secret`, filled by `secret_crypto_box()`.
    secret_crypto_box: OnceCell<crypto_box::SecretKey>,
}
/// Prints a fixed placeholder instead of any key material.
impl Debug for SecretKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SecretKey(..)")
    }
}
/// Formats the secret key bytes as 64 lowercase hex characters.
///
/// NOTE: unlike `Debug`, this reveals the secret. The output round-trips through the `FromStr`
/// impl.
impl Display for SecretKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = data_encoding::HEXLOWER.encode(self.secret.as_bytes());
        f.write_str(&hex)
    }
}
/// Parses a secret key from 64 hex characters or from unpadded base32.
impl FromStr for SecretKey {
    type Err = KeyParsingError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        decode_base32_hex(s).map(SecretKey::from)
    }
}
impl Serialize for SecretKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.secret.serialize(serializer)
    }
}
/// Reads a `SigningKey` in its serde representation and wraps it.
impl<'de> Deserialize<'de> for SecretKey {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        SigningKey::deserialize(deserializer).map(Self::from)
    }
}
impl SecretKey {
    pub fn public(&self) -> PublicKey {
        self.secret.verifying_key().into()
    }
    pub fn generate() -> Self {
        let mut rng = rand::rngs::OsRng;
        Self::generate_with_rng(&mut rng)
    }
    pub fn generate_with_rng<R: CryptoRngCore + ?Sized>(csprng: &mut R) -> Self {
        let secret = SigningKey::generate(csprng);
        Self {
            secret,
            secret_crypto_box: OnceCell::default(),
        }
    }
    pub fn to_openssh(&self) -> ssh_key::Result<zeroize::Zeroizing<String>> {
        let ckey = ssh_key::private::Ed25519Keypair {
            public: self.secret.verifying_key().into(),
            private: self.secret.clone().into(),
        };
        ssh_key::private::PrivateKey::from(ckey).to_openssh(LineEnding::default())
    }
    pub fn try_from_openssh<T: AsRef<[u8]>>(data: T) -> anyhow::Result<Self> {
        let ser_key = ssh_key::private::PrivateKey::from_openssh(data)?;
        match ser_key.key_data() {
            ssh_key::private::KeypairData::Ed25519(kp) => Ok(SecretKey {
                secret: kp.private.clone().into(),
                secret_crypto_box: OnceCell::default(),
            }),
            _ => anyhow::bail!("invalid key format"),
        }
    }
    pub fn sign(&self, msg: &[u8]) -> Signature {
        use ed25519_dalek::Signer;
        self.secret.sign(msg)
    }
    pub fn to_bytes(&self) -> [u8; 32] {
        self.secret.to_bytes()
    }
    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
        let secret = SigningKey::from_bytes(bytes);
        secret.into()
    }
    fn secret_crypto_box(&self) -> &crypto_box::SecretKey {
        self.secret_crypto_box
            .get_or_init(|| secret_ed_box(&self.secret))
    }
}
/// Wraps a signing key; the derived `crypto_box` key is left empty until first requested.
impl From<SigningKey> for SecretKey {
    fn from(secret: SigningKey) -> Self {
        Self {
            secret_crypto_box: OnceCell::new(),
            secret,
        }
    }
}
impl From<[u8; 32]> for SecretKey {
    fn from(value: [u8; 32]) -> Self {
        Self::from_bytes(&value)
    }
}
/// Builds a secret key from a byte slice.
///
/// `SigningKey::try_from` decides which slices are accepted.
impl TryFrom<&[u8]> for SecretKey {
    type Error = SignatureError;

    #[inline]
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        SigningKey::try_from(bytes).map(Self::from)
    }
}
fn decode_base32_hex(s: &str) -> Result<[u8; 32], KeyParsingError> {
    let mut bytes = [0u8; 32];
    let res = if s.len() == PublicKey::LENGTH * 2 {
        data_encoding::HEXLOWER.decode_mut(s.as_bytes(), &mut bytes)
    } else {
        data_encoding::BASE32_NOPAD.decode_mut(s.to_ascii_uppercase().as_bytes(), &mut bytes)
    };
    match res {
        Ok(len) => {
            if len != PublicKey::LENGTH {
                return Err(KeyParsingError::DecodeInvalidLength);
            }
        }
        Err(partial) => return Err(partial.error.into()),
    }
    Ok(bytes)
}
#[cfg(test)]
mod tests {
    use iroh_test::{assert_eq_hex, hexdump::parse_hexdump};
    use super::*;
    // The non-human-readable (postcard) encoding of a public key is exactly its 32 raw bytes,
    // with no length prefix in front of them.
    #[test]
    fn test_public_key_postcard() {
        let public_key =
            PublicKey::from_str("ae58ff8833241ac82d6ff7611046ed67b5072d142c588d0063e942d9a75502b6")
                .unwrap();
        let bytes = postcard::to_stdvec(&public_key).unwrap();
        let expected =
            parse_hexdump("ae58ff8833241ac82d6ff7611046ed67b5072d142c588d0063e942d9a75502b6")
                .unwrap();
        assert_eq_hex!(bytes, expected);
    }
    // A freshly generated secret key survives an OpenSSH encode/decode round trip.
    #[test]
    fn test_secret_key_openssh_roundtrip() {
        let kp = SecretKey::generate();
        let ser = kp.to_openssh().unwrap();
        let de = SecretKey::try_from_openssh(&ser).unwrap();
        assert_eq!(kp.to_bytes(), de.to_bytes());
    }
    // Binary serde round trip. This relies on the all-zero 32-byte encoding being accepted by
    // `PublicKey::from_bytes`.
    #[test]
    fn public_key_postcard() {
        let key = PublicKey::from_bytes(&[0; 32]).unwrap();
        let bytes = postcard::to_stdvec(&key).unwrap();
        let key2: PublicKey = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(key, key2);
    }
    // Human-readable serde round trip (the key travels as a hex string).
    #[test]
    fn public_key_json() {
        let key = PublicKey::from_bytes(&[0; 32]).unwrap();
        let bytes = serde_json::to_string(&key).unwrap();
        let key2: PublicKey = serde_json::from_str(&bytes).unwrap();
        assert_eq!(key, key2);
    }
    // `Display` output of both key types parses back to the same key via `FromStr`.
    #[test]
    fn test_display_from_str() {
        let key = SecretKey::generate();
        assert_eq!(
            SecretKey::from_str(&key.to_string()).unwrap().to_bytes(),
            key.to_bytes()
        );
        assert_eq!(
            PublicKey::from_str(&key.public().to_string()).unwrap(),
            key.public()
        );
    }
    // Every public-key construction path leaves an entry in the global key cache. The cache is
    // shared across tests; random keys keep the checks independent of other tests, and the
    // 60 s TTL is far longer than the gap between insert and check.
    #[test]
    fn test_key_creation_cache() {
        let random_verifying_key = || {
            let sk = SigningKey::generate(&mut rand::thread_rng());
            sk.verifying_key()
        };
        let random_public_key = || random_verifying_key().to_bytes();
        // Path: PublicKey::from_bytes.
        let k1 = random_public_key();
        let _key = PublicKey::from_bytes(&k1).unwrap();
        assert!(lock_key_cache().contains_key(&k1));
        // Path: TryFrom<&[u8; 32]>.
        let k2 = random_public_key();
        let _key = PublicKey::try_from(&k2).unwrap();
        assert!(lock_key_cache().contains_key(&k2));
        // Path: TryFrom<&[u8]>.
        let k3 = random_public_key();
        let _key = PublicKey::try_from(k3.as_slice()).unwrap();
        assert!(lock_key_cache().contains_key(&k3));
        // Path: From<VerifyingKey>.
        let k4 = random_verifying_key();
        let _key = PublicKey::from(k4);
        assert!(lock_key_cache().contains_key(k4.as_bytes()));
        // Path: binary Deserialize. Assumption (confirm): postcard writes VerifyingKey's serde
        // bytes with a one-byte length prefix. Skipping it leaves the bare 32-byte array layout
        // that PublicKey's non-human-readable Deserialize expects.
        let k5 = random_verifying_key();
        let bytes = postcard::to_stdvec(&k5).unwrap();
        let _key: PublicKey = postcard::from_bytes(&bytes[1..]).unwrap();
        assert!(lock_key_cache().contains_key(k5.as_bytes()));
    }
}