//
// jja: swiss army knife for chess file formats
// src/pgn.rs: Portable Game Notation utilities
//
// Copyright (c) 2023, 2024 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0-or-later

use std::{
    fmt::{self, Display, Formatter},
    fs::File,
    io::{stdout, Read, StdoutLock, Write},
    str::FromStr,
};

use anyhow::{bail, Context};
use indicatif::ProgressBar;
use pgcopy::Encoder;
use pgn_reader::{BufferedReader, RawHeader, SanPlus, Skip, Visitor};
use shakmaty::{
    fen::{Epd, Fen},
    CastlingMode, Chess, EnPassantMode, Position, PositionError,
};

use crate::{
    chess::serialize_chess,
    hash::{zobrist16_hash, zobrist32_hash, zobrist8_hash, zobrist_hash},
    stockfish::stockfish_hash,
    system::get_progress_bar,
    tr,
};

/// Output formats understood by the `pgn_dump` function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum OutputFormat {
    /// PostgreSQL binary `COPY` stream, one tuple per position.
    Binary,
    /// Comma separated values, one row per position.
    Csv,
    /// JSON arrays, one array per line and position.
    Json,
    /// Extended Position Description, one line per position.
    Epd,
}

impl Display for OutputFormat {
    /// Writes the short upper-case name of the format, e.g. `CSV`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Binary => "BIN",
            Self::Csv => "CSV",
            Self::Json => "JSON",
            Self::Epd => "EPD",
        };
        f.write_str(name)
    }
}

/// Dump elements for `pgn_dump` function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DumpElement {
    /// Represents a Zobrist8 hash of the position.
    Zobrist8,
    /// Represents a Zobrist16 hash of the position.
    Zobrist16,
    /// Represents a Zobrist32 hash of the position.
    Zobrist32,
    /// Represents a Zobrist64 hash of the position.
    Zobrist64,
    /// Represents a Stockfish compatible Zobrist64 hash of the position.
    Zobrist64SF,
    /// Represents a serialized array of the position.
    Position,
}

impl Display for DumpElement {
    /// Writes the canonical long name of the element.
    ///
    /// Every name written here is accepted by the `FromStr` implementation, so
    /// `element.to_string().parse()` round-trips.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            DumpElement::Zobrist8 => write!(f, "zobrist8"),
            DumpElement::Zobrist16 => write!(f, "zobrist16"),
            // Previously this printed "zobrist8", which broke the round-trip.
            DumpElement::Zobrist32 => write!(f, "zobrist32"),
            DumpElement::Zobrist64 => write!(f, "zobrist64"),
            DumpElement::Zobrist64SF => write!(f, "zobrist64sf"),
            DumpElement::Position => write!(f, "position"),
        }
    }
}

impl FromStr for DumpElement {
    type Err = ();

    /// Parses a dump element from its long name or a short alias.
    ///
    /// Long names are the ones produced by `Display`. The short aliases are `z8`, `z16`,
    /// `z32`, `id`/`z64`, `sf`/`z64sf` and `p`/`pos`. Matching is case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for any unrecognized name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "z8" | "zobrist8" => Ok(DumpElement::Zobrist8),
            "z16" | "zobrist16" => Ok(DumpElement::Zobrist16),
            "z32" | "zobrist32" => Ok(DumpElement::Zobrist32),
            "id" | "z64" | "zobrist64" => Ok(DumpElement::Zobrist64),
            "sf" | "z64sf" | "zobrist64sf" => Ok(DumpElement::Zobrist64SF),
            "p" | "pos" | "position" => Ok(DumpElement::Position),
            _ => Err(()),
        }
    }
}

/// `PositionTracker` represents a chess position and tracks its validity while replaying the
/// mainline of PGN games.
///
/// It is driven through its `pgn_reader::Visitor` implementation. Depending on the configured
/// output format, it can dump each position it visits.
pub struct PositionTracker<'a> {
    /// The current chess position.
    position: Chess,
    /// Whether the initial position has been processed, i.e. a valid `FEN` header was applied
    /// or the end of the headers was reached.
    init: bool,
    /// Whether the current game is still valid. Cleared on an invalid `FEN` header or an
    /// illegal move.
    valid: bool,
    /// Writer for text output. If `PositionTracker::new` was called with `None` as the output
    /// argument, this writer directs to the standard output.
    output: Box<dyn Write + 'a>,
    /// Optional output format. `None` disables dumping positions.
    format: Option<OutputFormat>,
    /// Optional list of elements to dump per position. `None` selects the 64-bit Zobrist hash
    /// followed by the serialized position.
    elements: Option<&'a Vec<DumpElement>>,
    /// Optional progress bar, incremented once per dumped position.
    progress_bar: Option<&'a ProgressBar>,
    /// Optional PostgreSQL binary encoder, used for `OutputFormat::Binary`.
    postgres_enc: Option<Encoder<StdoutLock<'a>>>,
}

impl<'a> PositionTracker<'a> {
    /// Constructs a new `PositionTracker` at the default (standard starting) chess position.
    ///
    /// The `valid` field is set to `true`.
    ///
    /// # Arguments
    ///
    /// - `output: Option<Box<dyn Write + 'a>>`: Optional writer for text output (EPD, CSV and
    ///   JSON). If `None`, text output is sent to the standard output. PostgreSQL binary
    ///   output goes through the `encoder` argument instead.
    /// - `format: Option<OutputFormat>`: Optional output format. If `None`, no positions are
    ///   dumped.
    /// - `elements: Option<&Vec<DumpElement>>`: Optional list of elements to dump for CSV,
    ///   JSON and binary output. If `None`, the 64-bit Zobrist hash followed by the serialized
    ///   position is dumped.
    /// - `encoder: Option<Encoder<StdoutLock<'a>>>`: Optional PostgreSQL encoder. It must be
    ///   `Some` when `format` is `Some(OutputFormat::Binary)`.
    /// - `progress_bar: Option<&ProgressBar>`: Optional progress bar, incremented once per
    ///   dumped position.
    ///
    /// # Returns
    /// A new instance of `PositionTracker`.
    pub fn new(
        output: Option<Box<dyn Write + 'a>>,
        format: Option<OutputFormat>,
        elements: Option<&'a Vec<DumpElement>>,
        encoder: Option<Encoder<StdoutLock<'a>>>,
        progress_bar: Option<&'a ProgressBar>,
    ) -> Self {
        let output = match output {
            Some(file) => file,
            None => Box::new(stdout()),
        };

        PositionTracker {
            output,
            format,
            elements,
            progress_bar,
            postgres_enc: encoder,
            position: Chess::default(),
            init: false,
            valid: true,
        }
    }

    /// Writes the PostgreSQL binary header if a PostgreSQL encoder is configured.
    ///
    /// # Errors
    ///
    /// Returns an error if the encoder fails to write the header.
    fn write_header(&mut self) -> anyhow::Result<()> {
        if let Some(encoder) = self.postgres_enc.as_mut() {
            encoder
                .write_header()
                .context(tr!("Failed to write PostgreSQL binary header"))?;
        }

        Ok(())
    }

    /// Writes the PostgreSQL binary trailer if a PostgreSQL encoder is configured.
    ///
    /// # Errors
    ///
    /// Returns an error if the encoder fails to write the trailer.
    fn write_trailer(&mut self) -> anyhow::Result<()> {
        if let Some(encoder) = self.postgres_enc.as_mut() {
            encoder
                .write_trailer()
                .context(tr!("Failed to write PostgreSQL binary trailer"))?;
        }

        Ok(())
    }

    /// Computes the numeric values to dump for the current position.
    ///
    /// The values follow the order of the configured `elements`. When no elements are
    /// configured, the 64-bit Zobrist hash is followed by the serialized position.
    fn collect_elements(&self) -> Vec<u64> {
        let mut elements = Vec::new();
        match self.elements {
            Some(dump_elements) => {
                for dump_element in dump_elements {
                    match dump_element {
                        DumpElement::Zobrist8 => {
                            elements.push(u64::from(zobrist8_hash(&self.position)));
                        }
                        DumpElement::Zobrist16 => {
                            elements.push(u64::from(zobrist16_hash(&self.position)));
                        }
                        DumpElement::Zobrist32 => {
                            elements.push(u64::from(zobrist32_hash(&self.position)));
                        }
                        DumpElement::Zobrist64 => {
                            elements.push(zobrist_hash(&self.position));
                        }
                        DumpElement::Zobrist64SF => {
                            elements.push(stockfish_hash(&self.position));
                        }
                        DumpElement::Position => {
                            elements.extend(serialize_chess(&self.position));
                        }
                    }
                }
            }
            None => {
                elements.push(zobrist_hash(&self.position));
                elements.extend(serialize_chess(&self.position));
            }
        }
        elements
    }

    /// Dumps the current position in the configured output format, if any.
    ///
    /// EPD, CSV and JSON lines go to `self.output`. Binary tuples go to the PostgreSQL
    /// encoder. The progress bar, if any, is incremented once per dumped position.
    ///
    /// # Panics
    ///
    /// Panics if writing fails. Also panics if the format is `OutputFormat::Binary` but no
    /// PostgreSQL encoder was supplied to `PositionTracker::new`.
    fn dump_position(&mut self) {
        let format = match self.format {
            Some(format) => format,
            None => return,
        };

        match format {
            OutputFormat::Epd => {
                // TODO: Add context from PGN headers into c0 & c1 fields like pgn-extract.
                // TODO: Use --elements to let the user pick which EPD fields to print.
                writeln!(
                    self.output,
                    "{} id {:#x};",
                    Epd::from_position(self.position.clone(), EnPassantMode::PseudoLegal),
                    zobrist_hash(&self.position)
                )
                .expect("epd write");
            }
            OutputFormat::Binary => {
                let elements = self.collect_elements();
                // `new` is public, so a caller may request binary output without an encoder.
                // Panic instead of invoking undefined behaviour via `unwrap_unchecked`.
                let enc = self
                    .postgres_enc
                    .as_mut()
                    .expect("binary output format requires a PostgreSQL encoder");
                // PostgreSQL tuples carry an i16 field count. The element list is assumed to
                // stay far below that limit.
                enc.write_tuple(elements.len() as i16)
                    .expect("pgcopy write tuple");
                for element in elements {
                    // Reinterpret the unsigned value as a signed bigint.
                    enc.write_bigint(element as i64)
                        .expect("pgcopy write element");
                }
            }
            OutputFormat::Csv => {
                // CSV values are reinterpreted as signed 64-bit numbers.
                let row = self
                    .collect_elements()
                    .into_iter()
                    .map(|element| (element as i64).to_string())
                    .collect::<Vec<_>>()
                    .join(",");
                writeln!(self.output, "{}", row).expect("csv write");
            }
            OutputFormat::Json => {
                // JSON values are kept as unsigned 64-bit numbers.
                let array = self
                    .collect_elements()
                    .into_iter()
                    .map(|element| element.to_string())
                    .collect::<Vec<_>>()
                    .join(",");
                writeln!(self.output, "[{}]", array).expect("json write");
            }
        }

        if let Some(progress_bar) = self.progress_bar {
            progress_bar.inc(1);
        }
    }
}

/// Implements the `pgn_reader` `Visitor` trait for `PositionTracker`.
///
/// This lets the tracker replay games given in SAN notation:
/// - it applies the `FEN` header;
/// - it plays the mainline moves, skipping variations;
/// - it resets itself at the end of every game.
///
/// The `Result` of each visited game is the final `Chess` position reached.
impl Visitor for PositionTracker<'_> {
    type Result = Chess;

    /// Handles a single PGN header.
    ///
    /// Only the `FEN` header is interpreted. A valid FEN replaces the current position
    /// (supporting games from a non-standard starting position), and that position is
    /// dumped. An unparsable or illegal FEN is reported on standard error and marks the game
    /// as invalid, so its movetext is skipped.
    fn header(&mut self, key: &[u8], value: RawHeader<'_>) {
        if key != b"FEN" {
            return;
        }

        // Support games from a non-standard starting position.
        let fen = match Fen::from_ascii(value.as_bytes()) {
            Ok(fen) => fen,
            Err(err) => {
                eprintln!(
                    "{}",
                    tr!(
                        "Skipping invalid FEN header: {} ({}).",
                        err,
                        format!("{:?}", value)
                    )
                );
                self.valid = false;
                return;
            }
        };

        // Tolerate invalid en passant squares and castling rights instead of rejecting the
        // whole game.
        self.position = match fen
            .into_position(CastlingMode::Chess960)
            .or_else(PositionError::ignore_invalid_ep_square)
            .or_else(PositionError::ignore_invalid_castling_rights)
        {
            Ok(pos) => pos,
            Err(err) => {
                eprintln!(
                    "{}",
                    tr!(
                        "Skipping illegal FEN header: {} ({}).",
                        err,
                        format!("{:?}", value)
                    )
                );
                self.valid = false;
                return;
            }
        };
        // Given multiple FEN headers (which is super-rare and formally
        // broken PGN), we process all of them for dump, and use the
        // last one for position tracker.
        self.dump_position();
        self.init = true;
    }

    /// Called once all headers of a game have been read.
    ///
    /// Dumps the standard starting position unless a FEN header already supplied (and dumped)
    /// the initial position, or the game was invalidated by a broken FEN header.
    ///
    /// # Returns
    /// `Skip(true)` for invalid games, so their movetext is skipped.
    fn end_headers(&mut self) -> Skip {
        if !self.init {
            // A broken FEN header means the game does not start from the standard position,
            // so dumping the default position would be wrong.
            if self.valid {
                self.dump_position();
            }
            self.init = true;
        }
        Skip(!self.valid)
    }

    /// Processes a move in SAN notation and updates the position.
    ///
    /// If the position is invalid, this function does nothing. An illegal move is reported
    /// on standard error and marks the game as invalid. All later moves of the game are then
    /// ignored.
    ///
    /// # Parameters
    /// * `san_plus`: A `SanPlus` containing the move in SAN notation.
    fn san(&mut self, san_plus: SanPlus) {
        if !self.valid {
            return;
        }

        let san = san_plus.san;
        let mov = match san.to_move(&self.position) {
            Ok(mov) => mov,
            Err(err) => {
                let epd = Epd::from_position(self.position.clone(), EnPassantMode::PseudoLegal)
                    .to_string();
                eprintln!(
                    "{}",
                    tr!("illegal SAN move `{}' in position `{}': {}", san, epd, err)
                );
                self.valid = false;
                return;
            }
        };

        // `to_move` has already validated the move against the current position.
        self.position.play_unchecked(&mov);
        self.dump_position();
    }

    /// Handles the beginning of a variation.
    ///
    /// The implementation always returns `Skip(true)`, which causes the parser
    /// to stay in the mainline and ignore the variation.
    ///
    /// # Returns
    /// A `Skip` instance indicating whether to skip the variation or not.
    fn begin_variation(&mut self) -> Skip {
        Skip(true) // stay in the mainline
    }

    /// Handles the end of a game and resets the tracker for the next one.
    ///
    /// Clears both the invalid state and the initial-position flag, so the next game is
    /// processed from scratch. The internal position is reset to the default chess position.
    ///
    /// # Returns
    /// The `Chess` instance representing the final position of the game. If an illegal move
    /// was encountered, this is the position reached before that move.
    fn end_game(&mut self) -> Self::Result {
        self.valid = true;
        // Without this reset, only the first game lacking a FEN header would get its
        // starting position dumped.
        self.init = false;
        std::mem::take(&mut self.position)
    }
}

/// Converts a PGN string into an EPD string representing the final position.
///
/// This function takes a PGN string as input, processes the moves, and then
/// returns a string containing the final position in EPD format.
///
/// # Parameters
/// * `pgn`: A `&str` containing the PGN of the chess game.
///
/// # Returns
/// A `String` containing the final position of the game in EPD format.
pub fn pgn2epd(pgn: &str) -> String {
    let mut reader = BufferedReader::new_cursor(pgn);

    let mut tracker: PositionTracker<'_> = PositionTracker::new(None, None, None, None, None);
    let position = reader
        .read_game(&mut tracker)
        .expect("pgn")
        .expect("invalid pgn argument");

    format!(
        "{}",
        Epd::from_position(position, EnPassantMode::PseudoLegal)
    )
}

/// Converts a PGN file into a JSON stream of arrays which is a 6-sized array, the first element is
/// the Zobrist hash, the second to sixth elements are the serialized position as returned by
/// `jja::chess::serialize_chess` function or into a stream of CSV arrays which is again a 6-sized
/// array with the exact same structure as the JSON array except in CSV format the numbers are
/// cast into signed 64-bit numbers, whereas in JSON output they're unsigned 64-bit numbers.
/// Compressed PGN files are supported.
pub fn pgn_dump(
    file_name: &str,
    format: OutputFormat,
    elements: &Vec<DumpElement>,
) -> anyhow::Result<()> {
    eprintln!(
        "{}",
        tr!(
            "Dumping all positions in PGN file `{}' in format {} to standard output...",
            file_name,
            format
        )
    );

    let file = match File::open(file_name) {
        Ok(file) => file,
        Err(err) => {
            bail!(
                "{}",
                tr!("Failed to open PGN file `{}': {}", file_name, err)
            );
        }
    };

    let file_ext = std::path::Path::new(file_name).extension();
    let uncompressed: Box<dyn Read + Send> =
        if file_ext.map_or(false, |ext| ext.eq_ignore_ascii_case("zst")) {
            Box::new(match zstd::Decoder::new(file) {
                Ok(decoder) => decoder,
                Err(err) => {
                    bail!(
                        "{}",
                        tr!("Failed to open PGN file `{}': {}", file_name, err)
                    );
                }
            })
        } else if file_ext.map_or(false, |ext| ext.eq_ignore_ascii_case("bz2")) {
            Box::new(bzip2::read::MultiBzDecoder::new(file))
        } else if file_ext.map_or(false, |ext| ext.eq_ignore_ascii_case("xz")) {
            Box::new(xz2::read::XzDecoder::new(file))
        } else if file_ext.map_or(false, |ext| ext.eq_ignore_ascii_case("gz")) {
            Box::new(flate2::read::GzDecoder::new(file))
        } else if file_ext.map_or(false, |ext| ext.eq_ignore_ascii_case("lz4")) {
            Box::new(match lz4::Decoder::new(file) {
                Ok(decoder) => decoder,
                Err(err) => {
                    bail!(
                        "{}",
                        tr!("Failed to open PGN file `{}': {}", file_name, err)
                    );
                }
            })
        } else {
            Box::new(file)
        };

    let progress_bar = get_progress_bar(0);
    progress_bar.set_message(tr!("Dumping:"));

    let encoder: Option<Encoder<StdoutLock>> = match format {
        OutputFormat::Binary => Some(Encoder::new(stdout().lock())),
        _ => None,
    };

    let mut tracker = PositionTracker::new(
        None,
        Some(format),
        Some(elements),
        encoder,
        Some(&progress_bar),
    );
    tracker.write_header()?;
    BufferedReader::new(uncompressed).read_all(&mut tracker)?;
    tracker.write_trailer()?;

    let count = progress_bar.position();
    progress_bar.finish_and_clear();

    eprintln!(
        "{}",
        tr!(
            "Successfully dumped {} positions from PGN file `{}' in format {} to standard output.",
            count,
            file_name,
            format
        )
    );
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pgn2epd() {
        let pgn = "1. e4 e5 2. Nf3 Nc6 3. Bb5";
        let epd = pgn2epd(pgn);
        assert_eq!(
            epd,
            "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq -"
        );
    }

    #[test]
    fn test_pgn2epd_fen_header() {
        // Games may start from a non-standard position given in the FEN header.
        let pgn = "[FEN \"8/8/8/8/8/8/8/K6k w - - 0 1\"]\n\n1. Kb2 *";
        assert_eq!(pgn2epd(pgn), "8/8/8/8/8/8/1K6/7k b - -");
    }

    #[test]
    fn test_pgn2epd_stops_at_illegal_move() {
        // The illegal king move is rejected; the position before it is returned.
        let pgn = "1. e4 e5 2. Ke3 *";
        assert_eq!(
            pgn2epd(pgn),
            "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq -"
        );
    }

    #[test]
    fn test_dump_element_from_str() {
        assert_eq!("z8".parse::<DumpElement>(), Ok(DumpElement::Zobrist8));
        assert_eq!("z32".parse::<DumpElement>(), Ok(DumpElement::Zobrist32));
        assert_eq!("id".parse::<DumpElement>(), Ok(DumpElement::Zobrist64));
        assert_eq!("sf".parse::<DumpElement>(), Ok(DumpElement::Zobrist64SF));
        assert_eq!("pos".parse::<DumpElement>(), Ok(DumpElement::Position));
        assert_eq!("bogus".parse::<DumpElement>(), Err(()));
    }

    #[test]
    fn test_output_format_display() {
        assert_eq!(OutputFormat::Binary.to_string(), "BIN");
        assert_eq!(OutputFormat::Csv.to_string(), "CSV");
        assert_eq!(OutputFormat::Json.to_string(), "JSON");
        assert_eq!(OutputFormat::Epd.to_string(), "EPD");
    }
}
