# MIT License Copyright (c) 2025 Agapitus Keyka Vigiliant
# MIT License Copyright (c) 2020 Tomasz Sobczyk
# GNU GPLv3 (C) 2012-2021 Niklas Fiekas <niklas.fiekas@backscattering.de>

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import sqlite3
from functools import lru_cache
from json import load as load_json
from re import compile as regex_compile
from subprocess import PIPE, Popen
from threading import Thread
from typing import Any

from chess import Board

# TODO: tidak usah gunakan module chess, kita kurang lebih hanya butuh atribut
# berikut: fen, set_fen, push_uci, dan copy. Ide, buat representasi papan
# dalam bentuk bitboard. Empat atribut tersebut seharusnya mudah dibuat, dan
# itu juga akan meringkas kode di encode_fen()

# A parsed UCI "info" line: field name -> value (see _parse_uci_info).
Info = dict[str, Any]

STARTING_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

# Mate scores are stored as +-MATE_SCORE offset by the mate distance so they
# share one integer column with centipawn scores. NOTE(review): a genuine
# centipawn score beyond +-(MATE_SCORE - 100) is printed back as a mate score
# by _unparse_uci_info; changing this value would, however, invalidate the
# mate scores already stored in existing databases.
MATE_SCORE = 2**12

# One UCI move token: a normal move or promotion ("e7e8q"), a drop ("N@f3"),
# or the null move "0000". The alternatives are grouped so that "^" and "\Z"
# anchor every one of them, not just the first and the last.
UCI_REGEX = regex_compile(r"^(?:[a-h][1-8][a-h][1-8][pnbrqk]?|[PNBRQK]@[a-h][1-8]|0000)\Z")

# File letter -> the 8 squares of that file, index r being rank r + 1
# (squares numbered a1 = 0, b1 = 1, ..., h8 = 63).
CHESS_FILE = {c: [8 * r + f for r in range(8)] for f, c in enumerate("abcdefgh")}

# FEN piece letter -> (piece type 1..6 for pawn..king, True if the piece is white).
PIECE_MAP_ENCODE = {
    "P": (1, True),
    "N": (2, True),
    "B": (3, True),
    "R": (4, True),
    "Q": (5, True),
    "K": (6, True),
    "p": (1, False),
    "n": (2, False),
    "b": (3, False),
    "r": (4, False),
    "q": (5, False),
    "k": (6, False),
}


@lru_cache(maxsize=4096)
def encode_fen(fen: str) -> bytes:
    """Encode a FEN into a compact byte string used as the database key.

    Based on the compressed position format by Tomasz Sobczyk:
    https://github.com/official-stockfish/nnue-pytorch/blob/master/lib/nnue_training_data_formats.h#L4615

    The encoded integer consists of:

    * the low 64 bits: an occupancy bitboard in which square 0 (a1) is the
      most significant of those 64 bits and square 63 (h8) the least;
    * above them: one 4-bit code per occupied square, in the same square
      order (the a1 side ends up in the highest nibble):

        0..11  white pawn, black pawn, white knight, black knight, ...,
               white king, black king
        12     the pawn that just made a double push, when the FEN has an
               en passant square
        13     rook on a1 with castling right "Q", or on h1 with "K"
        14     rook on a8 with castling right "q", or on h8 with "k"
        15     black king when black is to move

    The halfmove clock and fullmove number are not encoded, so positions that
    differ only in those counters share a key. The encoding is reversible,
    but no decoder is needed here.

    Args:
        fen: FEN string; only its first four fields (placement, side to move,
            castling rights, en passant square) are read.

    Returns:
        The integer as big-endian bytes, using the fewest whole bytes that hold
        64 + 4 * (number of pieces) bits. It is stored as a BLOB because it is
        too large for an SQLite INTEGER.
    """
    fields = fen.split()
    placement, turn, castling, ep_target = fields[0], fields[1], fields[2], fields[3]

    # Square -> (piece type, is white), equivalent to chess.Board().piece_map().
    # FEN lists rank 8 first and each rank from file a to h, so walking every
    # rank right-to-left visits squares 63, 62, ..., 0 in descending order.
    piece_map: dict[int, tuple[int, bool]] = {}
    square = 63
    for rank in placement.split("/"):
        for char in reversed(rank):
            if char in "12345678":
                square -= int(char)  # a run of empty squares
            else:
                piece_map[square] = PIECE_MAP_ENCODE[char]
                square -= 1

    # The pawn that just double-pushed stands one rank beyond the en passant
    # target square (e3 -> e4, e6 -> e5). Only that pawn gets code 12: marking
    # every pawn on the file would drop the colour of the others and let
    # different positions share a key.
    ep_pawn = -1
    if ep_target != "-":
        ep_pawn = CHESS_FILE[ep_target[0]][3 if ep_target[1] == "3" else 4]

    # Rook square -> (castling letter that must be present, code to use).
    castling_rooks = {0: ("Q", 13), 7: ("K", 13), 56: ("q", 14), 63: ("k", 14)}

    occupancy, nibbles, bit_length = 0, 0, 64
    for square in range(64):
        occupancy <<= 1
        if square not in piece_map:
            continue
        occupancy |= 1
        bit_length += 4

        piece_type, is_white = piece_map[square]
        code = 2 * (piece_type - 1) + (not is_white)  # 0..11
        if piece_type == 1 and square == ep_pawn:
            code = 12
        elif piece_type == 4 and square in castling_rooks:
            letter, rook_code = castling_rooks[square]
            if letter in castling:
                code = rook_code
        elif piece_type == 6 and turn == "b" and not is_white:
            code = 15
        nibbles = nibbles << 4 | code

    num = nibbles << 64 | occupancy
    # Fewest whole bytes that hold bit_length bits; explicit byte order keeps
    # this working before Python 3.11 (where it became optional).
    return num.to_bytes((bit_length + 7) // 8, "big")


def _parse_uci_info(text_info: str) -> Info:
    # modifikasi dari kode chess.engine._parse_uci_info
    # oleh Niklas Fiekas <niklas.fiekas@backscattering.de>

    info: Info = {}
    tokens = text_info.split(" ")
    try:
        while tokens:
            parameter = tokens.pop(0)

            if parameter == "string":
                info["string"] = " ".join(tokens)
                break
            elif parameter in [
                "depth",
                "seldepth",
                "nodes",
                "multipv",
                "currmovenumber",
                "hashfull",
                "nps",
                "tbhits",
                "cpuload",
            ]:
                info[parameter] = int(tokens.pop(0))
            elif parameter == "time":
                info["time"] = int(tokens.pop(0))
            elif parameter == "ebf":
                info["ebf"] = float(tokens.pop(0))
            elif parameter == "score":
                kind = tokens.pop(0)
                value = int(tokens.pop(0))
                if tokens and tokens[0] in ["lowerbound", "upperbound"]:
                    info[tokens.pop(0)] = True
                if kind == "cp":
                    info["score"] = value
                elif kind == "mate":
                    if value > 0:
                        info["score"] = MATE_SCORE - value
                    else:
                        info["score"] = -MATE_SCORE - value
                else:
                    raise Exception("Unknown score kind")
            elif parameter == "currmove":
                info["currmove"] = tokens.pop(0)
            elif parameter == "currline":
                if "currline" not in info:
                    info["currline"] = {}

                cpunr = int(tokens.pop(0))
                currline: list[str] = []
                info["currline"][cpunr] = currline

                while tokens and UCI_REGEX.match(tokens[0]):
                    currline.append(tokens.pop(0))
            elif parameter == "refutation":
                if "refutation" not in info:
                    info["refutation"] = {}

                refuted = tokens.pop(0)
                refuted_by: list[str] = []
                info["refutation"][refuted] = refuted_by

                while tokens and UCI_REGEX.match(tokens[0]):
                    refuted_by.append(tokens.pop(0))
            elif parameter == "pv":
                pv: list[str] = []
                info["pv"] = pv
                while tokens and UCI_REGEX.match(tokens[0]):
                    pv.append(tokens.pop(0))
            elif parameter == "wdl":
                info["wdl"] = (
                    int(tokens.pop(0)),
                    int(tokens.pop(0)),
                    int(tokens.pop(0)),
                )
    except ValueError, IndexError:
        raise ValueError("Exception when parsing info")
    return info


def _unparse_uci_info(info: Info) -> str:
    text = ["info"]
    for k, v in info.items():
        text.append(k)
        if k == "pv":
            text.extend(info["pv"])
        elif k == "score":
            if v > MATE_SCORE - 100:
                text.append(f"mate {MATE_SCORE - v}")
            elif v < 100 - MATE_SCORE:
                text.append(f"mate {-MATE_SCORE - v}")
            else:
                text.append(f"cp {v}")
        else:
            text.append(str(v))
    return " ".join(text)


class Database:
    """SQLite cache of engine analysis.

    Table ``board`` holds one row per (encoded position, multipv index) with
    the search ``depth``, ``score``, ``nodes`` and the first ``move`` of that
    line. Longer lines are rebuilt by following the ``multipv=1`` move from
    one cached position to the next.
    """

    def __init__(self, database: str) -> None:
        """Open (or create) the SQLite database ``database`` and ensure the schema."""

        def dict_factory(cursor: sqlite3.Cursor, row: sqlite3.Row) -> dict[str, Any]:
            # Turn every fetched row into {column name: value}.
            return {column[0]: row[index] for index, column in enumerate(cursor.description)}

        # check_same_thread=False: the connection is also used from a thread
        # other than the one that opened it.
        self.db = sqlite3.connect(database, check_same_thread=False)
        self.db.row_factory = dict_factory
        self.db.executescript(
            """
            PRAGMA journal_mode = wal;
            PRAGMA synchronous = normal;
            PRAGMA temp_store = memory;
            PRAGMA mmap_size = 30000000000;
            PRAGMA busy_timeout = 10000;

            PRAGMA wal_autocheckpoint;

            CREATE TABLE IF NOT EXISTS board(
                fen         BLOB    NOT NULL,
                multipv     INTEGER NOT NULL,
                depth       INTEGER NOT NULL,
                score       INTEGER NOT NULL,
                nodes       INTEGER NOT NULL,
                move        TEXT,
                PRIMARY KEY (fen, multipv)
                ) WITHOUT ROWID;
            """
        )

    def close(self) -> None:
        """Close the underlying connection."""
        self.db.close()

    def commit(self) -> None:
        """Commit the pending transaction."""
        self.db.commit()

    def _get_moves(self, board: Board, depth: int) -> list[str]:
        """Follow cached multipv=1 best moves from ``board`` for up to ``depth`` plies.

        WARNING: each move found is pushed onto ``board`` itself, so the
        caller's board is modified. Stops early at the first position that
        has no cached entry.

        Returns:
            The moves, in UCI notation, in playing order.
        """
        query = "SELECT move FROM board WHERE fen=? AND multipv=1"
        moves: list[str] = []

        for _ in range(depth):
            row = self.db.execute(query, (encode_fen(board.fen()),)).fetchone()
            if not row:
                break
            best = row["move"]
            moves.append(best)
            board.push_uci(best)
        return moves

    def get_info(
        self, root_board: Board, multipv: int, with_pv: bool = False
    ) -> Info | None:
        """Return cached analysis of ``root_board`` for line ``multipv``, or None.

        The result has keys ``multipv``, ``depth``, ``score`` and ``nodes``.
        With ``with_pv`` it also has ``pv``: the cached move followed by up to
        ``depth - 1`` further cached best moves. ``root_board`` is not modified.
        """
        stt = """
            SELECT multipv, depth, score, nodes, move
            FROM board WHERE fen=? AND multipv=?
        """
        key = encode_fen(root_board.fen())
        row: Info | None = self.db.execute(stt, (key, multipv)).fetchone()
        if not row:
            return None

        first_move = row.pop("move")
        if not with_pv:
            return row

        continuation = root_board.copy(stack=False)
        continuation.push_uci(first_move)
        row["pv"] = [first_move, *self._get_moves(board=continuation, depth=row["depth"] - 1)]
        return row

    def cache_or_ignore(self, root_board: Board, info: Info) -> None:
        """Store ``info`` for ``root_board`` and propagate it along its pv.

        The root position is written under ``info["multipv"]``. Every later
        position of the pv gets an approximated entry: multipv 1, the score
        negated (the side to move alternates), depth reduced by one per ply
        and nodes set to 0. Propagation stops when the depth reaches 0, the pv
        runs out, or the cached entry for that position reports more nodes
        (the existing entry is kept). The transaction is committed at the end.

        ``info`` must contain ``pv``, ``depth``, ``multipv``, ``nodes`` and
        ``score``; it is copied, not modified.
        """
        stt = """
            INSERT INTO board (fen, multipv, depth, score, nodes, move)
            VALUES (:fen, :multipv, :depth, :score, :nodes, :move)
            ON CONFLICT (fen, multipv) DO UPDATE SET
                depth = excluded.depth,
                score = excluded.score,
                nodes = excluded.nodes,
                move  = excluded.move
        """
        board = root_board.copy()
        entry = info.copy()

        for move in entry["pv"]:
            if entry["depth"] == 0:
                break

            # Nodes are the measure of quality here; depth, time or a mix of
            # them would be possible alternatives.
            previous = self.get_info(board, entry["multipv"])
            if previous and previous["nodes"] > entry["nodes"]:
                break

            entry["fen"] = encode_fen(board.fen())
            entry["move"] = move
            self.db.execute(stt, entry)

            # Prepare the approximation for the next position of the pv.
            board.push_uci(move)
            entry["multipv"] = 1  # even if the root line was not multipv 1
            entry["score"] *= -1  # switch to the other side's point of view
            entry["depth"] -= 1
            entry["nodes"] = 0  # TODO: find a better approximation

        self.commit()


class UciProtocol:
    """UCI proxy between a GUI (this process's stdin/stdout) and an engine.

    GUI commands are forwarded to the engine unchanged; ``position`` commands
    are also replayed on ``self.board`` so the output side knows which
    position is being searched. Engine ``info`` lines that carry an exact
    score and a pv are stored in the analysis cache and re-emitted with the
    cached analysis. ``bestmove`` lines are re-emitted with the cached best
    (and ponder) move when one exists.
    """

    def __init__(self, settings_path: str):
        """Read settings from the JSON file ``settings_path``.

        Keys used: ``engine_path`` (the command passed to Popen) and
        ``database_path`` (SQLite path, default ``":memory:"``).
        """
        with open(settings_path, encoding="utf-8") as f:
            settings = load_json(f)

        self.engine_path = settings.get("engine_path")
        self.db_path = settings.get("database_path", ":memory:")

        self.board = Board()
        self.quit = False

    def parse_input(self) -> None:
        """Forward GUI commands from stdin to the engine until ``quit`` is read.

        Raises:
            EOFError: when stdin is closed (raised by ``input()``).
        """
        stream = self.engine.stdin
        assert stream is not None  # keeps mypy happy; start() requests stdin=PIPE

        while True:
            command = input().strip()
            if command == "quit":
                self.quit = True
                break

            tokens = command.split()  # UCI allows any whitespace between tokens
            if len(tokens) > 1 and tokens[0] == "position":
                self._set_position(tokens)

            stream.write(f"{command}\n")

    def _set_position(self, tokens: list[str]) -> None:
        """Replay ``position [startpos | fen <fields>] [moves <m>...]`` on self.board."""
        # Without "moves" the FEN runs to the end of the command; slicing up to
        # -1 instead would silently drop the FEN's last field.
        moves_at = tokens.index("moves") if "moves" in tokens else len(tokens)
        if tokens[1] == "startpos":
            self.board.set_fen(STARTING_FEN)
        elif tokens[1] == "fen":
            self.board.set_fen(" ".join(tokens[2:moves_at]))
        for move in tokens[moves_at + 1 :]:
            self.board.push_uci(move)

    def parse_output(self) -> None:
        """Relay engine output to stdout, rewriting info/bestmove lines from the cache.

        Returns once ``self.quit`` is set and a further line arrives, or when the
        engine's stdout reaches end of file.
        """
        stream = self.engine.stdout
        assert stream is not None  # keeps mypy happy; start() requests stdout=PIPE

        while not self.quit:
            line = stream.readline()
            if line == "":
                # EOF: the engine exited or was terminated. readline() would keep
                # returning "" from now on, so stop instead of spinning.
                break
            text = line.strip()
            if text == "":
                continue

            if text[:4] == "info":
                text = self._rewrite_info(text)
            elif text[:8] == "bestmove":
                text = self._rewrite_bestmove(text)

            # Flush every line so the GUI sees it immediately even when Python
            # is not started with -u (unbuffered).
            print(text, flush=True)

    def _rewrite_info(self, text: str) -> str:
        """Cache an engine ``info`` line and return it rewritten with the cached analysis.

        Lines that cannot be cached are returned unchanged: unparsable ones,
        bounded scores, and lines without score, depth, nodes or a non-empty pv.
        """
        try:
            info = _parse_uci_info(text)
        except ValueError:
            return text

        if (
            "lowerbound" in info
            or "upperbound" in info
            or not info.get("pv")
            or not all(key in info for key in ("score", "depth", "nodes"))
        ):
            return text

        # Some engines omit "multipv" when only one line is searched.
        multipv = info.get("multipv", 1)
        self.db.cache_or_ignore(self.board, {**info, "multipv": multipv})
        cached = self.db.get_info(self.board, multipv, with_pv=True)
        if cached is None:
            return text

        # Only overwrite fields the engine reported, so the rewritten line keeps
        # the engine's field order (pv stays at the end).
        info.update({key: value for key, value in cached.items() if key in info})
        return _unparse_uci_info(info)

    def _rewrite_bestmove(self, text: str) -> str:
        """Return a ``bestmove`` line using the cached best and ponder moves, if any."""
        # _get_moves pushes moves onto the board it receives; work on a copy so
        # the position mirrored from the GUI stays intact.
        moves = self.db._get_moves(self.board.copy(stack=False), depth=2)
        if len(moves) == 2:
            return f"bestmove {moves[0]} ponder {moves[1]}"
        if len(moves) == 1:
            return f"bestmove {moves[0]}"
        return text  # nothing cached: keep what the engine sent

    def start(self) -> None:
        """Launch the engine, proxy until the GUI quits or stdin closes, then clean up."""
        self.engine = Popen(
            self.engine_path,
            stdin=PIPE,
            stdout=PIPE,
            universal_newlines=True,
            bufsize=1,
        )
        self.db = Database(self.db_path)

        thread = Thread(target=self.parse_output)
        thread.start()

        try:
            self.parse_input()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or interrupted: shut down like "quit".
            pass
        finally:
            self.quit = True
            self.engine.terminate()
            # Join before closing: the output thread may still be writing to the
            # database while it drains the engine's last lines.
            thread.join()
            self.db.close()


if __name__ == "__main__":
    # Run the proxy with the settings file from the current working directory.
    UciProtocol(settings_path="./settings.json").start()
