"""Query the GitHub API and plot CI workflow run durations."""

import csv
import datetime
import json
import os
import urllib.error
import urllib.request

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np

_OWNER = "xadupre"
_REPO = "yet-another-onnxruntime-extensions"
_API_BASE = f"https://api.github.com/repos/{_OWNER}/{_REPO}"
_HEADERS = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
_CACHE_MAX_AGE_DAYS = 14
_USER_CACHE_DIR = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
_CACHE_DIR = os.path.join(
    _USER_CACHE_DIR, "yet-another-onnxruntime-extensions", "ci_durations_workflows"
)

# Workflows that are NOT CI (skip documentation / style / spelling workflows)
_SKIP_PATTERNS = ("docs", "style", "spelling", "pyrefly", "mypy", "doc_")


def _gh_get(path, params=""):
    """Performs a GET request to the GitHub REST API and returns parsed JSON.

    Returns:
        dict | None: Parsed JSON response, or None on any network or decode error.
    """
    url = f"{_API_BASE}/{path}"
    if params:
        url = f"{url}?{params}"
    req = urllib.request.Request(url, headers=_HEADERS)
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            return json.loads(resp.read().decode())
    except (urllib.error.URLError, urllib.error.HTTPError, TimeoutError, ValueError, OSError):
        return None


def _fetch_workflow_runs(workflow_id, since_iso):
    """Fetches all workflow runs for *workflow_id* created on or after *since_iso*.

    Returns:
        list[dict]: List of workflow run dictionaries from the GitHub API.
    """
    runs = []
    page = 1
    while True:
        data = _gh_get(
            f"actions/workflows/{workflow_id}/runs",
            f"per_page=100&page={page}&created=>={since_iso}&branch=main",
        )
        if data is None or not data.get("workflow_runs"):
            break
        runs.extend(data["workflow_runs"])
        if len(data["workflow_runs"]) < 100:
            break
        page += 1
    return runs


def _cache_path(workflow_id):
    """Returns the cache file path for one workflow.

    Returns:
        str: Absolute path to the CSV cache file for the given workflow ID.
    """
    return os.path.join(_CACHE_DIR, f"{workflow_id}.csv")


def _load_cached_runs(workflow_id):
    """Loads cached workflow runs for one workflow.

    Returns:
        list[dict[str, str]]: Cached run rows.
    """
    cache_path = _cache_path(workflow_id)
    if not os.path.exists(cache_path):
        return []
    with open(cache_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        rows = []
        for row in reader:
            rows.append(
                {
                    "created_at": row.get("created_at", ""),
                    "updated_at": row.get("updated_at", ""),
                    "status": row.get("status", ""),
                    "conclusion": row.get("conclusion", ""),
                }
            )
        return rows


def _save_cached_runs(workflow_id, runs):
    """Saves cached workflow runs for one workflow."""
    os.makedirs(_CACHE_DIR, exist_ok=True)
    cache_path = _cache_path(workflow_id)
    with open(cache_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f, fieldnames=["created_at", "updated_at", "status", "conclusion"]
        )
        writer.writeheader()
        for run in runs:
            writer.writerow(
                {
                    "created_at": run.get("created_at", ""),
                    "updated_at": run.get("updated_at", ""),
                    "status": run.get("status", ""),
                    "conclusion": run.get("conclusion", ""),
                }
            )


def _cache_is_recent(workflow_id, now):
    """Checks whether a workflow cache file is newer than the cache max age.

    Returns:
        bool: True when the cache entry is recent enough.
    """
    cache_path = _cache_path(workflow_id)
    if not os.path.exists(cache_path):
        return False
    try:
        fetched_ts = os.path.getmtime(cache_path)
        fetched_dt = datetime.datetime.fromtimestamp(fetched_ts, tz=datetime.timezone.utc)
    except OSError:
        return False
    age = now - fetched_dt
    return age <= datetime.timedelta(days=_CACHE_MAX_AGE_DAYS)


def _run_duration_minutes(run):
    """Computes the wall-clock duration of a successfully completed run.

    Returns:
        float | None: Duration in minutes, or None when unavailable.
    """
    if run.get("status") != "completed":
        return None
    if run.get("conclusion") != "success":
        return None
    created = run.get("created_at")
    updated = run.get("updated_at")
    if not created or not updated:
        return None
    try:
        fmt = "%Y-%m-%dT%H:%M:%SZ"
        t0 = datetime.datetime.strptime(created, fmt)
        t1 = datetime.datetime.strptime(updated, fmt)
        minutes = (t1 - t0).total_seconds() / 60.0
        return round(minutes, 2) if minutes >= 0 else None
    except ValueError:
        return None


def _collect_data():
    """Collects workflow duration data from cache and/or GitHub API.

    Returns:
        dict: Mapping ``workflow_name -> list[(datetime, duration_min)]``.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    cutoff = (
        now - datetime.timedelta(days=62)
    ).strftime("%Y-%m-%d")
    workflows_data = _gh_get("actions/workflows")
    if not workflows_data:
        return {}

    result = {}
    for wf in workflows_data.get("workflows", []):
        name = wf.get("name", "")
        wf_id = wf.get("id")
        path = wf.get("path", "")
        # Skip non-CI workflows
        filename = path.split("/")[-1].lower() if "/" in path else path.lower()
        if any(filename.startswith(p) for p in _SKIP_PATTERNS):
            continue

        cache_key = str(wf_id)
        cached_runs = _load_cached_runs(cache_key)
        if _cache_is_recent(cache_key, now):
            runs = cached_runs
        else:
            fetched = _fetch_workflow_runs(wf_id, cutoff)
            if fetched:
                runs = fetched
                _save_cached_runs(cache_key, runs)
            else:
                runs = cached_runs
        points = []
        for run in runs:
            dur = _run_duration_minutes(run)
            if dur is None:
                continue
            try:
                fmt = "%Y-%m-%dT%H:%M:%SZ"
                dt = datetime.datetime.strptime(run["created_at"], fmt)
                points.append((dt, dur))
            except (KeyError, ValueError):
                continue
        if points:
            points.sort(key=lambda x: x[0])
            result[name] = points

    return result


_PLOT_CUTOFF = datetime.datetime(2025, 1, 1)
_AVG_WINDOW = 10

data = _collect_data()

if not data:
    fig, ax = plt.subplots(figsize=(8, 2))
    ax.text(
        0.5, 0.5,
        "No data available (GitHub API not reachable or no completed runs found).",
        ha="center", va="center", transform=ax.transAxes, fontsize=11,
    )
    ax.axis("off")
    plt.tight_layout()
else:
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for idx, (wf_name, points) in enumerate(sorted(data.items())):
        fig, ax = plt.subplots(figsize=(10, 4))
        filtered = [(d, v) for d, v in points if d >= _PLOT_CUTOFF]
        dates = [p[0] for p in filtered]
        durations = [p[1] for p in filtered]

        color = colors[idx % len(colors)]
        ax.plot(dates, durations, "-", color=color, linewidth=1.2, label=wf_name)

        # Rolling average (window = 10)
        if len(durations) >= _AVG_WINDOW:
            kernel = np.ones(_AVG_WINDOW) / _AVG_WINDOW
            avg = np.convolve(durations, kernel, mode="valid")
            offset = (_AVG_WINDOW - 1) // 2
            avg_dates = dates[offset: offset + len(avg)]
            ax.plot(avg_dates, avg, "--", color="orange", linewidth=1.5, alpha=0.9,
                    label=f"{_AVG_WINDOW}-run avg")

        ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d"))
        ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
        ax.tick_params(axis="x", rotation=30, labelsize=8)
        ax.set_title(wf_name, fontsize=12)
        ax.set_ylabel("Duration (min)", fontsize=10)
        ax.grid(True, linestyle="--", alpha=0.4)
        if len(durations) >= _AVG_WINDOW:
            ax.legend(fontsize=9)
        fig.autofmt_xdate()
        plt.tight_layout()