use std::sync::Arc;

use pyo3::exceptions::{PyAssertionError, PyValueError};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyAny, PyDict, PyString};
use pyo3::{PyTraverseError, PyVisit, intern};

use crate::PydanticUseDefault;
use crate::errors::{
    ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ToErrorValue, ValError, ValResult,
    ValidationError,
};
use crate::input::Input;
use crate::py_gc::PyGcTraverse;
use crate::tools::{SchemaDict, function_name, safe_repr};

use super::generator::InternalValidator;
use super::{
    BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputType, ValidationState, Validator,
    build_validator,
};

/// The parts of a function validator schema's `function` dict needed to build a validator
/// (produced by `destructure_function_schema`).
struct FunctionInfo {
    /// The actual function object that will get called
    pub function: Py<PyAny>,
    /// Optional `field_name` from the `function` dict; callers use it when the validation
    /// state does not carry a field name.
    pub field_name: Option<Py<PyString>>,
    /// `true` for `"with-info"` functions (which also receive a `ValidationInfo`),
    /// `false` for `"no-info"` functions.
    pub info_arg: bool,
}

fn destructure_function_schema(schema: &Bound<'_, PyDict>) -> PyResult<FunctionInfo> {
    let func_dict: Bound<'_, PyDict> = schema.get_as_req(intern!(schema.py(), "function"))?;
    let function = func_dict.get_as_req(intern!(schema.py(), "function"))?;
    let func_type: Bound<'_, PyString> = func_dict.get_as_req(intern!(schema.py(), "type"))?;
    let info_arg = match func_type.to_str()? {
        "with-info" => true,
        "no-info" => false,
        _ => unreachable!(),
    };
    let field_name = func_dict.get_as(intern!(schema.py(), "field_name"))?;
    Ok(FunctionInfo {
        function,
        field_name,
        info_arg,
    })
}

/// Generate the `BuildValidator` impl for a validator that wraps an inner schema
/// (`schema["schema"]`) with a user function (`schema["function"]`).
///
/// `$impl_name` is the validator struct, which must have the fields `validator`, `func`,
/// `config`, `name`, `field_name` and `info_arg`. `$name` is the schema type string, and is
/// also the prefix of the validator's display name.
macro_rules! impl_build {
    ($impl_name:ident, $name:literal) => {
        impl BuildValidator for $impl_name {
            const EXPECTED_TYPE: &'static str = $name;
            fn build(
                schema: &Bound<'_, PyDict>,
                config: Option<&Bound<'_, PyDict>>,
                definitions: &mut DefinitionsBuilder<Arc<CombinedValidator>>,
            ) -> PyResult<Arc<CombinedValidator>> {
                let py = schema.py();
                // The inner `schema` is built before the `function` dict is read.
                let validator = build_validator(&schema.get_as_req(intern!(py, "schema"))?, config, definitions)?;
                let FunctionInfo {
                    function,
                    field_name,
                    info_arg,
                } = destructure_function_schema(schema)?;
                // Display name: "<type>[<function name>(), <inner validator name>]".
                let name = format!(
                    "{}[{}(), {}]",
                    $name,
                    function_name(function.bind(py))?,
                    validator.get_name()
                );
                // Python `None` stands in for a missing config.
                let config = config.map_or_else(|| py.None(), |c| c.clone().into());
                let built = Self {
                    validator,
                    func: function,
                    config,
                    name,
                    field_name,
                    info_arg,
                };
                Ok(Arc::new(built.into()))
            }
        }
    };
}

/// Validator for `function-before` schemas: calls the user function on the raw input and
/// then runs the inner `validator` on the function's return value.
#[derive(Debug)]
pub struct FunctionBeforeValidator {
    /// Validator built from the schema's inner `schema`.
    validator: Arc<CombinedValidator>,
    /// The user-supplied function.
    func: Py<PyAny>,
    /// The config dict given at build time, or Python `None`; forwarded via `ValidationInfo`.
    config: Py<PyAny>,
    /// Display name: `function-before[<function name>(), <inner validator name>]`.
    name: String,
    /// `field_name` from the schema's `function` dict; used when the validation state has none.
    field_name: Option<Py<PyString>>,
    /// Whether `func` also receives a `ValidationInfo` (`"with-info"` functions).
    info_arg: bool,
}

// Generates `BuildValidator` for the `"function-before"` schema type.
impl_build!(FunctionBeforeValidator, "function-before");

impl FunctionBeforeValidator {
    /// Call the user function on `input`, plus a `ValidationInfo` when `info_arg` is set.
    /// Then hand the function's return value to `call` together with `state`.
    ///
    /// Exceptions raised by the function are converted with `convert_err`, located at `input`.
    fn _validate<'s, 'py>(
        &'s self,
        call: impl FnOnce(Bound<'py, PyAny>, &mut ValidationState<'_, 'py>) -> ValResult<Py<PyAny>>,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &'s mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let outcome = if !self.info_arg {
            self.func.call1(py, (input.to_object(py)?,))
        } else {
            // Prefer the field name carried by the validation state; fall back to the schema's.
            let field_name = match state.field_name() {
                Some(current) => Some(current.clone().unbind()),
                None => self.field_name.clone(),
            };
            let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
            self.func.call1(py, (input.to_object(py)?, info))
        };
        let transformed = match outcome {
            Ok(value) => value.into_bound(py),
            Err(err) => return Err(convert_err(py, err, input)),
        };
        call(transformed, state)
    }
}

impl_py_gc_traverse!(FunctionBeforeValidator {
    validator,
    func,
    config
});

impl Validator for FunctionBeforeValidator {
    /// Run the user function first, then validate its output with the inner validator.
    fn validate<'py>(
        &self,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let inner = &self.validator;
        let run_inner = move |value, st: &mut ValidationState<'_, 'py>| inner.validate(py, &value, st);
        #[allow(clippy::used_underscore_items)]
        self._validate(run_inner, py, input, state)
    }
    /// Run the user function on `obj`, then let the inner validator perform the assignment of
    /// `field_name` = `field_value` on the function's output.
    fn validate_assignment<'py>(
        &self,
        py: Python<'py>,
        obj: &Bound<'py, PyAny>,
        field_name: &PyBackedStr,
        field_value: &Bound<'py, PyAny>,
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let inner = &self.validator;
        let assign = move |updated, st: &mut ValidationState<'_, 'py>| {
            inner.validate_assignment(py, &updated, field_name, field_value, st)
        };
        #[allow(clippy::used_underscore_items)]
        self._validate(assign, py, obj, state)
    }

    fn get_name(&self) -> &str {
        self.name.as_str()
    }
}

/// Validator for `function-after` schemas: runs the inner `validator` first, then calls the
/// user function on the validated value and returns the function's result.
#[derive(Debug)]
pub struct FunctionAfterValidator {
    /// Validator built from the schema's inner `schema`.
    validator: Arc<CombinedValidator>,
    /// The user-supplied function.
    func: Py<PyAny>,
    /// The config dict given at build time, or Python `None`; forwarded via `ValidationInfo`.
    config: Py<PyAny>,
    /// Display name: `function-after[<function name>(), <inner validator name>]`.
    name: String,
    /// `field_name` from the schema's `function` dict; used when the validation state has none.
    field_name: Option<Py<PyString>>,
    /// Whether `func` also receives a `ValidationInfo` (`"with-info"` functions).
    info_arg: bool,
}

// Generates `BuildValidator` for the `"function-after"` schema type.
impl_build!(FunctionAfterValidator, "function-after");

impl FunctionAfterValidator {
    /// Run `call` (the inner validation) on `input` first. Then pass its result to the user
    /// function, plus a `ValidationInfo` when `info_arg` is set, and return what it returns.
    ///
    /// Exceptions raised by the function are converted with `convert_err`, located at `input`.
    fn _validate<'py, I: Input<'py> + ?Sized>(
        &self,
        call: impl FnOnce(&I, &mut ValidationState<'_, 'py>) -> ValResult<Py<PyAny>>,
        py: Python<'py>,
        input: &I,
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let validated = call(input, state)?;
        let outcome = if !self.info_arg {
            self.func.call1(py, (validated,))
        } else {
            // Prefer the field name carried by the validation state; fall back to the schema's.
            let field_name = match state.field_name() {
                Some(current) => Some(current.clone().unbind()),
                None => self.field_name.clone(),
            };
            let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
            self.func.call1(py, (validated, info))
        };
        outcome.map_err(|err| convert_err(py, err, input))
    }
}

impl_py_gc_traverse!(FunctionAfterValidator {
    validator,
    func,
    config
});

impl Validator for FunctionAfterValidator {
    /// Validate `input` with the inner validator, then post-process it with the user function.
    fn validate<'py>(
        &self,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let inner = &self.validator;
        let run_inner = move |value: &_, st: &mut ValidationState<'_, 'py>| inner.validate(py, value, st);
        #[allow(clippy::used_underscore_items)]
        self._validate(run_inner, py, input, state)
    }
    /// Let the inner validator assign `field_name` = `field_value` on `obj`, then pass the
    /// result through the user function.
    fn validate_assignment<'py>(
        &self,
        py: Python<'py>,
        obj: &Bound<'py, PyAny>,
        field_name: &PyBackedStr,
        field_value: &Bound<'py, PyAny>,
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let inner = &self.validator;
        let assign = move |target: &Bound<'py, PyAny>, st: &mut ValidationState<'_, 'py>| {
            inner.validate_assignment(py, target, field_name, field_value, st)
        };
        #[allow(clippy::used_underscore_items)]
        self._validate(assign, py, obj, state)
    }

    fn get_name(&self) -> &str {
        self.name.as_str()
    }
}

/// Validator for `function-plain` schemas: the user function alone produces the result;
/// there is no inner validator.
#[derive(Debug, Clone)]
pub struct FunctionPlainValidator {
    /// The user-supplied function.
    func: Py<PyAny>,
    /// The config dict given at build time, or Python `None`; forwarded via `ValidationInfo`.
    config: Py<PyAny>,
    /// Display name: `function-plain[<function name>()]`.
    name: String,
    /// `field_name` from the schema's `function` dict; used when the validation state has none.
    field_name: Option<Py<PyString>>,
    /// Whether `func` also receives a `ValidationInfo` (`"with-info"` functions).
    info_arg: bool,
}

impl BuildValidator for FunctionPlainValidator {
    const EXPECTED_TYPE: &'static str = "function-plain";

    fn build(
        schema: &Bound<'_, PyDict>,
        config: Option<&Bound<'_, PyDict>>,
        _definitions: &mut DefinitionsBuilder<Arc<CombinedValidator>>,
    ) -> PyResult<Arc<CombinedValidator>> {
        let py = schema.py();
        let function_info = destructure_function_schema(schema)?;
        Ok(CombinedValidator::FunctionPlain(Self {
            func: function_info.function.clone(),
            config: match config {
                Some(c) => c.clone().into(),
                None => py.None(),
            },
            name: format!("function-plain[{}()]", function_name(function_info.function.bind(py))?),
            field_name: function_info.field_name.clone(),
            info_arg: function_info.info_arg,
        })
        .into())
    }
}

impl_py_gc_traverse!(FunctionPlainValidator { func, config });

impl Validator for FunctionPlainValidator {
    /// Call the user function on `input`, plus a `ValidationInfo` when `info_arg` is set, and
    /// return its result as the validated value.
    ///
    /// Exceptions raised by the function are converted with `convert_err`, located at `input`.
    fn validate<'py>(
        &self,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let outcome = if !self.info_arg {
            self.func.call1(py, (input.to_object(py)?,))
        } else {
            // Prefer the field name carried by the validation state; fall back to the schema's.
            let field_name = match state.field_name() {
                Some(current) => Some(current.clone().unbind()),
                None => self.field_name.clone(),
            };
            let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
            self.func.call1(py, (input.to_object(py)?, info))
        };
        outcome.map_err(|err| convert_err(py, err, input))
    }

    fn get_name(&self) -> &str {
        self.name.as_str()
    }
}

/// Validator for `function-wrap` schemas: the user function receives the input and a
/// handler it can call to run the inner `validator`.
#[derive(Debug)]
pub struct FunctionWrapValidator {
    /// Validator built from the schema's inner `schema`; the handler runs it.
    validator: Arc<CombinedValidator>,
    /// The user-supplied function.
    func: Py<PyAny>,
    /// The config dict given at build time, or Python `None`; forwarded via `ValidationInfo`.
    config: Py<PyAny>,
    /// Display name: `function-wrap[<function name>()]`.
    name: String,
    /// `field_name` from the schema's `function` dict; used when the validation state has none.
    field_name: Option<Py<PyString>>,
    /// Whether `func` also receives a `ValidationInfo` (`"with-info"` functions).
    info_arg: bool,
    /// `hide_input_in_errors` from config (default `false`), passed to the handler's `InternalValidator`.
    hide_input_in_errors: bool,
    /// `validation_error_cause` from config (default `false`), passed to the handler's `InternalValidator`.
    validation_error_cause: bool,
}

impl BuildValidator for FunctionWrapValidator {
    const EXPECTED_TYPE: &'static str = "function-wrap";

    fn build(
        schema: &Bound<'_, PyDict>,
        config: Option<&Bound<'_, PyDict>>,
        definitions: &mut DefinitionsBuilder<Arc<CombinedValidator>>,
    ) -> PyResult<Arc<CombinedValidator>> {
        let py = schema.py();
        let validator = build_validator(&schema.get_as_req(intern!(py, "schema"))?, config, definitions)?;
        let function_info = destructure_function_schema(schema)?;
        let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false);
        let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false);
        Ok(CombinedValidator::FunctionWrap(Self {
            validator,
            func: function_info.function.clone(),
            config: match config {
                Some(c) => c.clone().into(),
                None => py.None(),
            },
            name: format!("function-wrap[{}()]", function_name(function_info.function.bind(py))?),
            field_name: function_info.field_name.clone(),
            info_arg: function_info.info_arg,
            hide_input_in_errors,
            validation_error_cause,
        })
        .into())
    }
}

impl FunctionWrapValidator {
    /// Call the user function as `func(input, handler)`, or as `func(input, handler, info)`
    /// when `info_arg` is set, and return its result.
    ///
    /// Exceptions raised by the function are converted with `convert_err`, located at `input`.
    fn _validate<'py>(
        &self,
        handler: &Bound<'_, PyAny>,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let outcome = if !self.info_arg {
            self.func.call1(py, (input.to_object(py)?, handler))
        } else {
            // Prefer the field name carried by the validation state; fall back to the schema's.
            let field_name = match state.field_name() {
                Some(current) => Some(current.clone().unbind()),
                None => self.field_name.clone(),
            };
            let info = ValidationInfo::new(py, state.extra(), &self.config, field_name);
            self.func.call1(py, (input.to_object(py)?, handler, info))
        };
        outcome.map_err(|err| convert_err(py, err, input))
    }
}

impl_py_gc_traverse!(FunctionWrapValidator {
    validator,
    func,
    config
});

impl Validator for FunctionWrapValidator {
    fn validate<'py>(
        &self,
        py: Python<'py>,
        input: &(impl Input<'py> + ?Sized),
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let handler = ValidatorCallable {
            validator: InternalValidator::new(
                "ValidatorCallable",
                self.validator.clone(),
                state,
                self.hide_input_in_errors,
                self.validation_error_cause,
            ),
        };
        let handler = Bound::new(py, handler)?;
        #[allow(clippy::used_underscore_items)]
        let result = self._validate(handler.as_any(), py, input, state);
        let handler = handler.borrow();
        state.exactness = handler.validator.exactness;
        state.fields_set_count = handler.validator.fields_set_count;
        result
    }

    fn validate_assignment<'py>(
        &self,
        py: Python<'py>,
        obj: &Bound<'py, PyAny>,
        field_name: &PyBackedStr,
        field_value: &Bound<'py, PyAny>,
        state: &mut ValidationState<'_, 'py>,
    ) -> ValResult<Py<PyAny>> {
        let handler = AssignmentValidatorCallable {
            validator: InternalValidator::new(
                "AssignmentValidatorCallable",
                self.validator.clone(),
                state,
                self.hide_input_in_errors,
                self.validation_error_cause,
            ),
            updated_field_name: field_name.clone(),
            updated_field_value: field_value.clone().into(),
        };
        #[allow(clippy::used_underscore_items)]
        self._validate(Bound::new(py, handler)?.as_any(), py, obj, state)
    }

    fn get_name(&self) -> &str {
        &self.name
    }
}

// Python-visible `handler` passed to `function-wrap` functions; calling it runs the inner
// validator through an `InternalValidator`.
#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug)]
struct ValidatorCallable {
    validator: InternalValidator,
}

#[pymethods]
impl ValidatorCallable {
    // `handler(input_value, outer_location=None)`: validate `input_value`; `outer_location`
    // is converted with `Into` before being forwarded.
    #[pyo3(signature = (input_value, outer_location=None))]
    fn __call__(
        &mut self,
        py: Python,
        input_value: &Bound<'_, PyAny>,
        outer_location: Option<&Bound<'_, PyAny>>,
    ) -> PyResult<Py<PyAny>> {
        self.validator.validate(py, input_value, outer_location.map(Into::into))
    }

    fn __repr__(&self) -> String {
        let inner = &self.validator;
        format!("ValidatorCallable({inner:?})")
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }

    // Report the Python references held by the inner validator to the cycle GC.
    fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
        self.validator.py_gc_traverse(&visit)
    }
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[derive(Debug)]
struct AssignmentValidatorCallable {
    updated_field_name: PyBackedStr,
    updated_field_value: Py<PyAny>,
    validator: InternalValidator,
}

#[pymethods]
impl AssignmentValidatorCallable {
    #[pyo3(signature = (input_value, outer_location=None))]
    fn __call__(
        &mut self,
        py: Python,
        input_value: &Bound<'_, PyAny>,
        outer_location: Option<&Bound<'_, PyAny>>,
    ) -> PyResult<Py<PyAny>> {
        let outer_location = outer_location.map(Into::into);
        self.validator.validate_assignment(
            py,
            input_value,
            &self.updated_field_name,
            self.updated_field_value.bind(py),
            outer_location,
        )
    }

    fn __repr__(&self) -> String {
        format!("AssignmentValidatorCallable({:?})", self.validator)
    }

    fn __str__(&self) -> String {
        self.__repr__()
    }
}

/// Build a `ValError` of kind `ErrorType::$type_member` wrapping the exception `$py_err`,
/// located at `$input`.
///
/// `str()` of `$error_value` is computed first and must be convertible to a Rust `&str`.
/// If either step fails, that Python error is returned as `ValError::InternalErr` instead.
macro_rules! py_err_string {
    ($py:expr, $py_err:expr, $error_value:expr, $type_member:ident, $input:ident) => {
        match $error_value.str().and_then(|text| text.to_str().map(|_| ())) {
            Err(e) => ValError::InternalErr(e),
            Ok(()) => ValError::new(
                ErrorType::$type_member {
                    error: pyo3::IntoPyObjectExt::into_py_any($py_err, $py).ok(),
                    context: None,
                },
                $input,
            ),
        }
    };
}

/// Convert a Python exception raised by a user validation function into a [`ValError`].
///
/// Only `ValueError` and `AssertionError` are treated as validation errors. The `ValueError`
/// branch includes `PydanticCustomError`, `PydanticKnownError` and `ValidationError`.
/// `TypeError` and any other exception are treated as runtime errors
/// (`ValError::InternalErr`), so mistakes such as a wrong function signature are surfaced
/// rather than reported as validation failures. `PydanticOmit` and `PydanticUseDefault` map
/// to `ValError::Omit` and `ValError::UseDefault` respectively.
pub fn convert_err(py: Python<'_>, err: PyErr, input: impl ToErrorValue) -> ValError {
    if err.is_instance_of::<PyValueError>(py) {
        // Fetch the exception value once and reuse it for every cast below.
        let error_value = err.value(py);
        if let Ok(custom_error) = error_value.cast::<PydanticCustomError>() {
            custom_error.get().clone().into_val_error(input)
        } else if let Ok(known_error) = error_value.cast::<PydanticKnownError>() {
            known_error.get().clone().into_val_error(input)
        } else if let Ok(validation_error) = error_value.cast::<ValidationError>() {
            // Converted without `input`: the nested error is turned into a `ValError` as-is.
            validation_error.get().clone().into_val_error()
        } else {
            py_err_string!(py, err, error_value, ValueError, input)
        }
    } else if err.is_instance_of::<PyAssertionError>(py) {
        py_err_string!(py, err, err.value(py), AssertionError, input)
    } else if err.is_instance_of::<PydanticOmit>(py) {
        ValError::Omit
    } else if err.is_instance_of::<PydanticUseDefault>(py) {
        ValError::UseDefault
    } else {
        ValError::InternalErr(err)
    }
}

// The `info` argument passed to "with-info" validator functions. `get_all` exposes every
// field below to Python through generated getters.
#[pyclass(module = "pydantic_core._pydantic_core", get_all)]
pub struct ValidationInfo {
    // Config the validator was built with (Python `None` when there was none).
    config: Py<PyAny>,
    // Taken from `Extra::context` when present.
    context: Option<Py<PyAny>>,
    // Taken from `Extra::data` when present.
    data: Option<Py<PyDict>>,
    // Field name from the validation state, or the schema's `field_name` as a fallback
    // (chosen by the callers in this file).
    field_name: Option<Py<PyString>>,
    // Taken from `Extra::input_type`.
    mode: InputType,
}

// Implements `PyGcTraverse` over the Python-object fields (`mode` is not a Python object).
impl_py_gc_traverse!(ValidationInfo {
    config,
    context,
    data,
    field_name
});

impl ValidationInfo {
    /// Build the `info` argument for one validator-function call.
    ///
    /// `context`, `data` and `mode` are copied from `extra`; `config` gets a new reference.
    fn new(py: Python, extra: &Extra<'_, '_>, config: &Py<PyAny>, field_name: Option<Py<PyString>>) -> Self {
        Self {
            config: config.clone_ref(py),
            context: extra.context.map(|ctx| ctx.clone().into()),
            field_name,
            data: extra.data.as_ref().map(|data| data.clone().into()),
            mode: extra.input_type,
        }
    }
}

#[pymethods]
impl ValidationInfo {
    // GC hooks must live in `#[pymethods]` for PyO3 to install them on the type. In a plain
    // `impl` block they are never registered, and cycles through this object are not collected.
    fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
        self.py_gc_traverse(&visit)
    }

    // Break reference cycles by dropping the optional references that may point back at
    // user objects. (`config` is not optional; `field_name` is a plain string.)
    fn __clear__(&mut self) {
        self.context = None;
        self.data = None;
    }

    fn __repr__(&self, py: Python) -> PyResult<String> {
        let context = match self.context {
            Some(ref context) => safe_repr(context.bind(py)).to_string(),
            None => "None".into(),
        };
        let config = self.config.bind(py).repr()?;
        let data = match self.data {
            Some(ref data) => safe_repr(data.bind(py)).to_string(),
            None => "None".into(),
        };
        let field_name = match self.field_name {
            Some(ref field_name) => safe_repr(field_name.bind(py)).to_string(),
            None => "None".into(),
        };
        Ok(format!(
            "ValidationInfo(config={config}, context={context}, data={data}, field_name={field_name})"
        ))
    }
}
