#!/usr/bin/env python3
"""Group PostgreSQL slow-log statements by normalized fingerprint."""

from __future__ import annotations

import argparse
import csv
import re
import sys
from dataclasses import dataclass
from pathlib import Path


# Finds "duration: <ms> ms" followed by a statement/execute/parse/bind label
# and captures the milliseconds ("ms") and the rest of the line ("sql").
DURATION_RE = re.compile(
    r"duration:\s+(?P<ms>[0-9]+(?:\.[0-9]+)?)\s+ms\s+"
    r"(?:(?:statement)|(?:execute\s+[^:]+)|(?:parse\s+[^:]+)|(?:bind\s+[^:]+)):\s*(?P<sql>.*)",
    re.IGNORECASE,
)

# Recognizes the start of a log record: either a leading
# "YYYY-MM-DD HH:MM:SS" timestamp or a leading severity / message-field
# keyword. The keyword list covers every severity and field label the server
# writes, so records such as "FATAL:" or "NOTICE:" also end the preceding
# statement in logs that carry no timestamp prefix.
LOG_PREFIX_RE = re.compile(
    r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}"
    r"|^\s*(?:LOG|ERROR|FATAL|PANIC|WARNING|NOTICE|INFO|DEBUG"
    r"|DETAIL|HINT|QUERY|CONTEXT|LOCATION|BACKTRACE|STATEMENT):\s"
)
# Single-quoted literal; a doubled quote ('') is an escaped quote inside it.
STRING_RE = re.compile(r"'(?:''|[^'])*'")
# Stand-alone integer or decimal; digits glued to identifier characters are
# left untouched because of the word boundaries.
NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")
# IN (...) list made only of "?" placeholders, with optional commas and
# whitespace between them. Written so that each gap between placeholders can
# be matched in exactly one way: the earlier form `(?:\s*\?\s*,?)+` accepted
# the same text but backtracked exponentially when a long whitespace-separated
# run of placeholders was not followed by ")".
IN_LIST_RE = re.compile(r"\bIN\s*\(\s*\?(?:\s*(?:,\s*)?\?)*\s*,?\)", re.IGNORECASE)
WHITESPACE_RE = re.compile(r"\s+")


@dataclass
class Stat:
    """Running totals for all log entries that share one fingerprint."""

    # Number of entries folded in so far.
    calls: int = 0
    # Sum of the recorded durations, in milliseconds.
    total_ms: float = 0.0
    # Largest single duration recorded, in milliseconds.
    max_ms: float = 0.0
    # Shortened SQL of the slowest entry; stays empty unless examples are on.
    example: str = ""

    def add(self, duration_ms: float, sql: str, show_examples: bool) -> None:
        """Fold one log entry into the totals.

        Args:
            duration_ms: Duration of the entry in milliseconds.
            sql: Raw SQL text of the entry.
            show_examples: When true, keep a shortened copy of ``sql`` as the
                example if this entry is the slowest so far; when false the
                example is reset to an empty string instead.
        """
        self.calls = self.calls + 1
        self.total_ms = self.total_ms + duration_ms
        if duration_ms >= self.max_ms:
            # Ties are accepted too, so among equally slow entries the most
            # recent one supplies the example.
            self.max_ms = duration_ms
            if show_examples:
                self.example = compact(sql, 240)
            else:
                self.example = ""

    @property
    def mean_ms(self) -> float:
        """Average duration per call in milliseconds (0.0 before any call)."""
        if not self.calls:
            return 0.0
        return self.total_ms / self.calls


def compact(text: str, limit: int) -> str:
    """Collapse whitespace in *text* and cap its length at *limit* characters.

    Every run of whitespace becomes a single space and the ends are trimmed.
    Text that still exceeds *limit* is cut and marked with a trailing "...",
    the marker counting towards the limit.

    Args:
        text: Text to shorten.
        limit: Maximum length of the returned string.

    Returns:
        The collapsed text, shortened when necessary.
    """
    text = WHITESPACE_RE.sub(" ", text).strip()
    if len(text) <= limit:
        return text
    if limit < 3:
        # No room for the "..." marker. Cut plainly instead: the slice bound
        # `limit - 3` would be negative here and count from the end of the
        # text, yielding a result longer than the limit.
        return text[: max(limit, 0)]
    return text[: limit - 3].rstrip() + "..."


# One left-to-right scan over the tokens whose contents must not be
# interpreted as SQL: single-quoted literals, double-quoted identifiers,
# /* block */ comments and -- line comments. Whichever token starts first
# wins, so comment markers inside a literal (and quotes inside a comment)
# are treated as plain content of the enclosing token.
_QUOTED_OR_COMMENT_RE = re.compile(
    r"(?P<string>'(?:''|[^'])*')"
    r'|(?P<ident>"(?:""|[^"])*")'
    r"|/\*.*?\*/"
    r"|--[^\n]*",
    re.DOTALL,
)


def _mask_quoted_or_comment(match: re.Match) -> str:
    """Return the replacement for one token found by the scanner."""
    if match.group("string") is not None:
        return "?"
    if match.group("ident") is not None:
        return match.group(0)  # identifiers are part of the query shape
    return " "  # comments carry no query shape


def normalize(sql: str) -> str:
    """Reduce a SQL statement to a fingerprint shared by similar statements.

    Comments are removed, string literals and stand-alone numbers become
    "?", IN-lists made only of placeholders collapse to "IN (?)", and runs
    of whitespace become single spaces.

    Literals and comments are recognized in a single pass. Stripping
    comments first, as a separate step, would treat "--" or "/*" inside a
    string literal as a comment start and delete real SQL.

    Args:
        sql: Raw SQL text, possibly spanning several lines.

    Returns:
        The normalized fingerprint.
    """
    sql = _QUOTED_OR_COMMENT_RE.sub(_mask_quoted_or_comment, sql.strip())
    sql = NUMBER_RE.sub("?", sql)
    sql = IN_LIST_RE.sub("IN (?)", sql)
    return WHITESPACE_RE.sub(" ", sql).strip()


def _finish_entry(duration_ms, sql_lines):
    """Return the collected ``(duration_ms, sql)`` pair, or None if there is none."""
    if duration_ms is None:
        return None
    sql = "\n".join(sql_lines).strip()
    return (duration_ms, sql) if sql else None


def _parse_stream(stream):
    """Yield ``(duration_ms, sql)`` pairs from one open log stream."""
    duration_ms: float | None = None
    sql_lines: list[str] = []
    seen_log_prefix = False  # True once any prefixed line (timestamp/keyword) is seen

    for raw_line in stream:
        line = raw_line.rstrip("\n")
        is_log_line = LOG_PREFIX_RE.search(line) is not None
        seen_log_prefix = seen_log_prefix or is_log_line
        match = DURATION_RE.search(line)
        # In bare-format logs (no prefix seen yet), every DURATION_RE match is
        # a new entry. In prefixed-format logs, only start a new entry when the
        # line also carries a log prefix, so SQL continuation lines that happen
        # to contain "duration: N ms  statement:" text are not misidentified.
        starts_entry = match is not None and (
            duration_ms is None or not seen_log_prefix or is_log_line
        )

        if starts_entry or is_log_line:
            # Any prefixed line is a new log record and therefore ends the
            # statement being collected, whatever its message text says.
            entry = _finish_entry(duration_ms, sql_lines)
            if entry is not None:
                yield entry
            if starts_entry:
                duration_ms = float(match.group("ms"))
                sql_lines = [match.group("sql")]
            else:
                duration_ms, sql_lines = None, []
        elif duration_ms is not None:
            sql_lines.append(line)

    # End of stream closes the last entry.
    entry = _finish_entry(duration_ms, sql_lines)
    if entry is not None:
        yield entry


def iter_entries(paths: list[Path]):
    """Yield ``(duration_ms, sql)`` for every timed statement in the given logs.

    Each path is parsed independently: an entry never continues into the
    next file, and the prefixed/bare format detection starts afresh for
    every file. A path of "-" reads from standard input. Files that cannot
    be opened are reported on standard error and skipped.

    Args:
        paths: Log files to read, in order.

    Yields:
        Tuples of the duration in milliseconds and the statement text, with
        continuation lines joined by newlines.
    """
    for path in paths:
        reading_stdin = str(path) == "-"
        try:
            stream = sys.stdin if reading_stdin else path.open("r", encoding="utf-8", errors="replace")
        except OSError as exc:
            print(f"warning: cannot open {path}: {exc}", file=sys.stderr)
            continue

        try:
            yield from _parse_stream(stream)
        finally:
            # Only close what was opened here; standard input belongs to the
            # process and may be needed again.
            if not reading_stdin:
                stream.close()


def write_rows(rows, csv_path: Path | None, show_examples: bool) -> None:
    """Export the ranked fingerprints as a CSV file.

    Args:
        rows: Iterable of ``(fingerprint, stat)`` pairs in ranking order;
            each ``stat`` provides ``calls``, ``total_ms``, ``mean_ms``,
            ``max_ms`` and ``example``.
        csv_path: Destination file; nothing is written when it is not given.
        show_examples: When true, add an "example" column with each
            fingerprint's stored example.
    """
    if not csv_path:
        return

    columns = ["rank", "calls", "total_ms", "mean_ms", "max_ms", "fingerprint"]
    if show_examples:
        columns = columns + ["example"]

    def as_record(position, fingerprint, stat):
        # Durations are written with exactly three decimals.
        record = [
            position,
            stat.calls,
            format(stat.total_ms, ".3f"),
            format(stat.mean_ms, ".3f"),
            format(stat.max_ms, ".3f"),
            fingerprint,
        ]
        if show_examples:
            record.append(stat.example)
        return record

    # newline="" leaves line endings to the csv module.
    with csv_path.open("w", newline="", encoding="utf-8") as out:
        sink = csv.writer(out)
        sink.writerow(columns)
        for position, (fingerprint, stat) in enumerate(rows, start=1):
            sink.writerow(as_record(position, fingerprint, stat))


def main() -> int:
    """Run the command-line tool: parse the logs and print the ranking.

    Statements are read from the given logs, grouped by fingerprint and
    ranked by total duration. The ranking is optionally exported as CSV and
    its first ``--top`` rows are printed as a table on standard output.

    Returns:
        0 as the process exit status. Invalid arguments terminate the
        process through argparse instead.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("logs", nargs="+", type=Path, help="PostgreSQL log files, or '-' for stdin")
    parser.add_argument("--top", type=int, default=20, help="number of fingerprints to print")
    parser.add_argument("--csv", type=Path, help="optional CSV output path")
    parser.add_argument(
        "--show-examples",
        action="store_true",
        help="print and export raw SQL examples; use only after logs are redacted",
    )
    parser.add_argument(
        "--min-duration",
        type=float,
        default=0.0,
        metavar="MS",
        help="skip statements faster than this threshold in milliseconds (default: 0, include all)",
    )
    args = parser.parse_args()
    if args.top < 0:
        # Used as a slice bound below; a negative value would silently drop
        # rows from the end of the ranking instead of limiting its length.
        parser.error("--top must be zero or greater")

    stats: dict[str, Stat] = {}
    raw_count = 0  # every parsed statement
    included_count = 0  # statements that passed the --min-duration filter

    for duration_ms, sql in iter_entries(args.logs):
        raw_count += 1
        if duration_ms < args.min_duration:
            continue
        included_count += 1
        fingerprint = normalize(sql)
        stats.setdefault(fingerprint, Stat()).add(duration_ms, sql, args.show_examples)

    # Most expensive fingerprints first. The CSV export receives the full
    # ranking; only the printed table is limited by --top.
    rows = sorted(stats.items(), key=lambda item: item[1].total_ms, reverse=True)
    write_rows(rows, args.csv, args.show_examples)

    if args.min_duration > 0:
        print(f"raw_statements={raw_count} included={included_count} fingerprints={len(rows)}")
    else:
        print(f"parsed_statements={raw_count} fingerprints={len(rows)}")
    top_rows = rows[: args.top]
    # The rank and calls columns grow with their widest value; the duration
    # columns have fixed widths.
    col_w = {
        "rank": max(4, len(str(len(top_rows)))),
        "calls": max(5, max((len(str(s.calls)) for _, s in top_rows), default=5)),
        "total_ms": 12,
        "mean_ms": 10,
        "max_ms": 10,
    }
    header = (
        f"{'rank':>{col_w['rank']}}  "
        f"{'calls':>{col_w['calls']}}  "
        f"{'total_ms':>{col_w['total_ms']}}  "
        f"{'mean_ms':>{col_w['mean_ms']}}  "
        f"{'max_ms':>{col_w['max_ms']}}  fingerprint"
    )
    print(header)
    print("-" * len(header))
    for rank, (fingerprint, stat) in enumerate(top_rows, 1):
        print(
            f"{rank:>{col_w['rank']}}  "
            f"{stat.calls:>{col_w['calls']}}  "
            f"{stat.total_ms:>{col_w['total_ms']}.3f}  "
            f"{stat.mean_ms:>{col_w['mean_ms']}.3f}  "
            f"{stat.max_ms:>{col_w['max_ms']}.3f}  "
            f"{compact(fingerprint, 120)}"
        )
        if args.show_examples:
            print(f"  example: {stat.example}")

    if not args.show_examples:
        print("raw_examples=hidden use --show-examples only after logs are redacted")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
