//
// Syd: rock-solid application kernel
// src/kcov/mod.rs: KCOV userspace ABI shim for syzkaller
//
// Copyright (c) 2025, 2026 Ali Polatel <alip@chesswob.org>
// SPDX-License-Identifier: GPL-3.0

use std::{
    fmt,
    os::fd::{AsRawFd, RawFd},
    sync::{Arc, OnceLock, RwLock},
};

use nix::{errno::Errno, unistd::Pid};
use serde::{Serialize, Serializer};

use crate::hash::SydHashMap;

// KCOV ABI handlers
pub(crate) mod abi;

// KCOV API utilities
pub(crate) mod api;

// Thread-local sink describing where the live writer should send records.
//
// It is stored per thread in `TLS_SINK`, and is only ever copied out by value.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TlsSink {
    // KCOV handle that this thread's records are routed to.
    pub(crate) id: KcovId,
}

thread_local! {
    // Sink armed by `set_tls_sink` and disarmed by `clear_tls_sink`; `None` when idle.
    static TLS_SINK: RwLock<Option<TlsSink>> = const { RwLock::new(None) };
}

thread_local! {
    // While `true`, `get_tls_sink` reports no sink. Nothing in this section sets it;
    // presumably it is raised elsewhere to stop the writer re-entering itself (TODO confirm).
    static RECURSION_GUARD: RwLock<bool> = const { RwLock::new(false) };
}

// Return the KCOV id armed on the current thread, if any.
//
// Yields `None` in three cases:
// - the recursion guard is raised;
// - no sink is armed;
// - the thread-locals have already been destroyed (`try_with` fails).
// Lock poisoning is tolerated by reading through to the inner value.
pub(crate) fn get_tls_sink() -> Option<KcovId> {
    let guarded = RECURSION_GUARD
        .try_with(|flag| *flag.read().unwrap_or_else(|poisoned| poisoned.into_inner()))
        .ok()?;
    if guarded {
        return None;
    }

    TLS_SINK
        .try_with(|slot| {
            let armed = *slot.read().unwrap_or_else(|poisoned| poisoned.into_inner());
            armed.map(|sink| sink.id)
        })
        .ok()
        .flatten()
}

// Arm the current thread's sink so subsequent records go to `id`.
//
// Best-effort: if the thread-local is already destroyed, the request is ignored.
pub(crate) fn set_tls_sink(id: KcovId) {
    let sink = TlsSink { id };
    let _ = TLS_SINK.try_with(|slot| {
        let mut armed = slot.write().unwrap_or_else(|poisoned| poisoned.into_inner());
        *armed = Some(sink);
    });
}

// Disarm the current thread's sink.
//
// Best-effort: if the thread-local is already destroyed, the request is ignored.
pub(crate) fn clear_tls_sink() {
    let _ = TLS_SINK.try_with(|slot| {
        slot.write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take()
    });
}

// Global TID map: Tid -> (KcovId, is_remote)
#[expect(clippy::type_complexity)]
static KCOV_TID_MAP: OnceLock<RwLock<SydHashMap<Pid, (KcovId, bool)>>> = OnceLock::new();

#[expect(clippy::type_complexity)]
fn kcov_tid_map() -> &'static RwLock<SydHashMap<Pid, (KcovId, bool)>> {
    KCOV_TID_MAP.get_or_init(|| RwLock::new(SydHashMap::default()))
}

// Record that `tid` is traced by KCOV handle `id`.
//
// A remote registration never replaces an existing per-thread (non-remote) entry
// for the same TID. Every other registration overwrites whatever was there.
pub(crate) fn set_kcov_tid(tid: Pid, id: KcovId, is_remote: bool) {
    let mut tids = kcov_tid_map()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    if is_remote && matches!(tids.get(&tid), Some(&(_, false))) {
        return;
    }

    tids.insert(tid, (id, is_remote));
}

// Look up the per-thread KCOV handle for `tid`.
//
// Entries registered as remote are deliberately not reported; only non-remote
// entries yield `Some`.
pub(crate) fn get_kcov_tid(tid: Pid) -> Option<KcovId> {
    let tids = kcov_tid_map()
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    tids.get(&tid)
        .filter(|&&(_, is_remote)| !is_remote)
        .map(|&(id, _)| id)
}

// Remove the TID -> KcovId mapping for a given TID, remote or not.
//
// Intended to be called on KCOV_DISABLE, so that stale mappings are not reused
// after PID recycling.
pub(crate) fn remove_kcov_tid(tid: Pid) {
    // The write guard is a temporary and is released at the end of this statement.
    kcov_tid_map()
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&tid);
}

// Inherit KCOV mapping from parent to child on fork/clone/vfork.
//
// Only a per-thread (non-remote) parent mapping is copied, and it is copied as
// non-remote. Remote mappings and missing mappings leave the child untouched.
pub(crate) fn inherit_kcov_tid(parent_tid: Pid, child_tid: Pid) {
    // Snapshot the parent's entry. The read guard is a temporary dropped at the end of
    // this `let`, which must happen before `set_kcov_tid` takes the write lock below.
    let parent_entry = kcov_tid_map()
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .get(&parent_tid)
        .copied();

    if let Some((id, false)) = parent_entry {
        set_kcov_tid(child_tid, id, false);
    }
}

//
// Public API
//

// KCOV trace modes: `Pc` and `Cmp`.
//
// These presumably mirror Linux's KCOV_TRACE_PC / KCOV_TRACE_CMP (TODO confirm in abi).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum TraceMode {
    Pc,
    Cmp,
}

impl fmt::Display for TraceMode {
    // Render the lowercase mode name ("pc" / "cmp").
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Pc => "pc",
            Self::Cmp => "cmp",
        };
        f.write_str(name)
    }
}

impl Serialize for TraceMode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

// /sys/kernel/debug/kcov handle.
//
// This is an opaque numeric key for a KCOV handle, usable as a hash-map key.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct KcovId(u64);

impl KcovId {
    // Wrap a raw numeric id; usable in const contexts.
    pub(crate) const fn new(id: u64) -> Self {
        KcovId(id)
    }
}

impl AsRawFd for KcovId {
    // Resolve this id to the `syd_fd` descriptor stored for it in the KCOV registry.
    //
    // # Panics
    //
    // Panics if the id is absent from the registry, which indicates an internal bug.
    // `expect` is deliberate here (hence the lint allowance): callers must only hold
    // ids that are registered.
    #[allow(clippy::disallowed_methods)]
    fn as_raw_fd(&self) -> RawFd {
        let registry = crate::kcov::abi::kcov_reg()
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let ctx = registry
            .get(self)
            .expect("BUG: missing ID in KCOV registry, report a bug!");
        ctx.syd_fd.as_raw_fd()
    }
}

// KCOV Device manager
pub(crate) struct Kcov {
    // KcovId to State map.
    map: RwLock<SydHashMap<KcovId, Arc<State>>>,
}

impl Kcov {
    pub(crate) fn new() -> Self {
        Self {
            map: RwLock::new(SydHashMap::default()),
        }
    }

    // Create KCOV instance.
    pub(crate) fn open(&self, kcov_id: u64) -> Result<(), Errno> {
        let kcov_id = KcovId(kcov_id);
        let state_arc = Arc::new(State::new());

        let mut map = self.map.write().unwrap_or_else(|e| e.into_inner());
        map.insert(kcov_id, state_arc);

        Ok(())
    }

    // KCOV_INIT_TRACE(words): Core tracks only phase, not size.
    pub(crate) fn init_trace(&self, kcov_id: KcovId, words: u64) -> Result<(), Errno> {
        self.get(kcov_id)?.init_trace(words)
    }

    // KCOV_ENABLE: Activate this KCOV id.
    pub(crate) fn enable(&self, id: KcovId, mode: TraceMode) -> Result<(), Errno> {
        let st = self.get(id)?;
        st.enable(mode)?;

        // Arm the TLS.
        set_tls_sink(id);

        Ok(())
    }

    // KCOV_DISABLE: Transition phase back to Init and clear TLS on this worker.
    pub(crate) fn disable(&self, id: KcovId) -> Result<(), Errno> {
        let st = self.get(id)?;
        st.disable()?;

        // Clear TLS for this worker thread (best-effort).
        clear_tls_sink();

        Ok(())
    }

    fn get(&self, kcov_id: KcovId) -> Result<Arc<State>, Errno> {
        let read_guard = self.map.read().unwrap_or_else(|e| e.into_inner());
        read_guard.get(&kcov_id).cloned().ok_or(Errno::EBADF)
    }
}

//
// Internals
//

// Lifecycle phase of a KCOV handle.
//
// The phase starts at `Disabled`, moves to `Init` after KCOV_INIT_TRACE and to
// `Enabled` after KCOV_ENABLE.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Phase {
    Disabled,
    Init,
    Enabled,
}

impl fmt::Display for Phase {
    // Render the lowercase phase name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disabled => f.write_str("disabled"),
            Self::Init => f.write_str("init"),
            Self::Enabled => f.write_str("enabled"),
        }
    }
}

impl Serialize for Phase {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

// Per-handle KCOV state. `Kcov` stores it behind an `Arc`, keyed by `KcovId`.
struct State {
    // Phase and mode share one lock, so each transition is checked and applied atomically.
    core: RwLock<Core>,
}

// Lock-protected part of a handle's state machine.
struct Core {
    // Mode chosen by the last successful KCOV_ENABLE; reset to `None` by KCOV_INIT_TRACE.
    mode: Option<TraceMode>,
    // Lifecycle: Disabled --INIT_TRACE--> Init --ENABLE--> Enabled --DISABLE--> Init.
    phase: Phase,
}

impl State {
    // Create a handle that has not yet seen KCOV_INIT_TRACE.
    fn new() -> Self {
        Self {
            core: RwLock::new(Core {
                mode: None,
                phase: Phase::Disabled,
            }),
        }
    }

    // KCOV_INIT_TRACE: validate the buffer size (in words) and move Disabled -> Init.
    //
    // Errors:
    // - EINVAL if `words` is below 2 or above `i32::MAX / 8`. These bounds match
    //   Linux's `INT_MAX / sizeof(unsigned long)` on 64-bit.
    // - EBUSY if the handle has already left the Disabled phase.
    fn init_trace(&self, words: u64) -> Result<(), Errno> {
        if words < 2 || words > (i32::MAX as u64) / 8 {
            return Err(Errno::EINVAL);
        }

        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());
        if core.phase != Phase::Disabled {
            return Err(Errno::EBUSY);
        }

        core.mode = None;
        core.phase = Phase::Init;

        Ok(())
    }

    // KCOV_ENABLE: move Init -> Enabled with the requested mode.
    //
    // Re-enabling an already-enabled handle with the same mode succeeds as a no-op.
    // Errors: EBUSY for every other phase/mode combination, including Disabled.
    fn enable(&self, mode: TraceMode) -> Result<(), Errno> {
        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());

        match core.phase {
            Phase::Init => {
                core.mode = Some(mode);
                core.phase = Phase::Enabled;
                Ok(())
            }

            // Idempotent enable:
            // Already enabled with the same mode succeeds.
            Phase::Enabled if core.mode == Some(mode) => Ok(()),

            _ => Err(Errno::EBUSY),
        }
    }

    // KCOV_DISABLE: return an initialised handle to Init so it can be re-enabled.
    //
    // Errors: EINVAL if the handle never went through KCOV_INIT_TRACE. Promoting a
    // Disabled handle straight to Init would let a later KCOV_ENABLE skip
    // KCOV_INIT_TRACE and its size validation entirely.
    fn disable(&self) -> Result<(), Errno> {
        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());
        if core.phase == Phase::Disabled {
            return Err(Errno::EINVAL);
        }
        core.phase = Phase::Init;
        Ok(())
    }
}
