#!/usr/bin/env python

# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple changelog generator

This tool takes a list of release versions to generate `CHANGELOG.md`.
The changelog will include all merged PRs w/o `[bot]` postfix,
and issues that include the labels `bug`, `feature request`, `SQL`, `performance`, `shuffle`,
minus any issues with the labels `wontfix`, `invalid` or `duplicate`.

For each project there should be an issue subsection for,
Features: all issues with label `feature request` + `SQL`
Performance: all issues with label `performance` + `shuffle`
Bugs fixed: all issues with label `bug`

To deduplicate section, the priority should be `Bugs fixed > Performance > Features`

NOTE: This is a repo-specific script, so you may not use it in other places.
Release version from arguments will be used to map project name and branch name.
Mapping Pattern:
release -> project: Release {release}, branch: branch-{release}
e.g.
    0.1 -> project: Release 0.1, branch: branch-0.1
    0.2 -> project: Release 0.2, branch: branch-0.2

Dependencies:
    - requests
    - python-dateutil

GitHub personal access token: https://github.com/settings/tokens, and make sure you have the `repo` and `project` scopes selected

Usage:
    cd spark-rapids/

    # generate changelog for releases X, Y, Z
    scripts/generate-changelog --token=<GITHUB_PERSONAL_ACCESS_TOKEN> \
    --releases=X,Y,Z

    # To a separate file like /tmp/CHANGELOG.md
    GITHUB_TOKEN=<GITHUB_PERSONAL_ACCESS_TOKEN> scripts/generate-changelog \
    --releases=X,Y,Z \
    --path=/tmp/CHANGELOG.md
"""
import os
import subprocess
import sys
from argparse import ArgumentParser
from collections import OrderedDict
from datetime import date, datetime, timezone

import requests
from dateutil.relativedelta import relativedelta

# Constants
# Prefix of a project name; project names are built as f"{RELEASE} {version}".
RELEASE = "Release"
# Resource types. The values double as the connection field names used to index
# the GraphQL response (data -> repository -> <resource type>).
PULL_REQUESTS = "pullRequests"
ISSUES = "issues"
# Subtitles
# Keys of the per-release changelog dict; also rendered as "### <subtitle>" headings.
# INVALID is never rendered: rules() returns it for issues that must be dropped.
INVALID = 'Invalid'
BUGS_FIXED = 'Bugs Fixed'
PERFORMANCE = 'Performance'
FEATURES = 'Features'
PRS = 'PRs'
# Labels
# GitHub label names that rules() maps onto the subtitles above.
LABEL_WONTFIX, LABEL_INVALID, LABEL_DUPLICATE = 'wontfix', 'invalid', 'duplicate'
LABEL_BUG = 'bug'
LABEL_PERFORMANCE, LABEL_SHUFFLE = 'performance', 'shuffle'
LABEL_FEATURE, LABEL_SQL = 'feature request', 'SQL'
# The release version from which the release branch changes (e.g., branch-YY.MM --> release/YY.MM)
# Compared numerically (via float()) against release versions elsewhere in this file.
FROM_RELEASE = '25.12'
# Queries
# query_pr: one page (up to 100) of MERGED pull requests of NVIDIA/spark-rapids whose
# base branch is $baseRefName. $after is the pagination cursor; pass the previous
# page's pageInfo.endCursor to continue. Each node carries its labels and the
# "Roadmap" single-select value of its project items.
query_pr = """
query ($baseRefName: String!, $after: String) {
  repository(name: "spark-rapids", owner: "NVIDIA") {
    pullRequests(states: [MERGED], baseRefName: $baseRefName, first: 100, after: $after) {
      totalCount
      nodes {
        number
        title
        headRefName
        baseRefName
        state
        url
        labels(first: 10) {
          nodes {
            name
          }
        }
        projectItems(first: 10) {
          nodes {
              roadmap: fieldValueByName(name: "Roadmap") {
                ... on ProjectV2ItemFieldSingleSelectValue {
                    name
                }
              }
          }
        }
        mergedAt
      }
      pageInfo {
        hasNextPage
        endCursor
      }
    }
  }
}
"""
# query_issue: one page (up to 100) of CLOSED issues of NVIDIA/spark-rapids selected by
# the label list below and by the $since timestamp. $after is the pagination cursor.
# NOTE(review): presumably the label list matches issues carrying ANY of the labels and
# `since` filters on the last-updated time -- confirm against the GitHub GraphQL docs.
query_issue = """
query ($after: String, $since: DateTime) {
  repository(name: "spark-rapids", owner: "NVIDIA") {
    issues(states: [CLOSED], labels: ["SQL", "feature request", "performance", "bug", "shuffle"], first: 100, after: $after, filterBy: {since: $since}) {
      totalCount
      nodes {
        number
        title
        state
        url
        labels(first: 10) {
          nodes {
            name
          }
        }
        projectItems(first: 10) {
          nodes {
            roadmap: fieldValueByName(name: "Roadmap") {
              ... on ProjectV2ItemFieldSingleSelectValue {
                name
              }
            }
          }
        }
        closedAt
      }
      pageInfo {
        hasNextPage
        endCursor
      }
    }
  }
}
"""
# query_pr_by_commit: resolve the commit identified by $sha in NVIDIA/spark-rapids and
# return up to 10 pull requests associated with it, each with its labels, merge time
# and the "Roadmap" single-select value of its project items. Not paginated.
query_pr_by_commit = """
query ($sha: String!) {
  repository(name: "spark-rapids", owner: "NVIDIA") {
    commit: object(expression: $sha) {
      ... on Commit {
        associatedPullRequests(first: 10) {
          edges {
            node {
              title
              number
              state
              url
              baseRefName
              labels(first: 10) {
                nodes {
                  name
                }
              }
              mergedAt
              projectItems(first: 10) {
                nodes {
                  roadmap: fieldValueByName(name: "Roadmap") {
                    ... on ProjectV2ItemFieldSingleSelectValue {
                      name
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
"""

def get_prev_release_version(current_ver: str):
    """Return the release version two months before `current_ver`.

    Versions use the `YY.MM` format; both fields of the result are zero-padded
    to two digits.

    param current_ver: the current version, e.g. '25.12'
    return: the previous version, e.g. '25.10'; months 01 and 02 roll back
            into the previous year ('25.02' -> '24.12')
    """
    yy, mm = (int(part) for part in current_ver.split("."))
    if mm <= 2:
        # Stepping back two months crosses the year boundary.
        yy, mm = yy - 1, mm + 10
    else:
        mm -= 2
    return f"{yy:02d}.{mm:02d}"

def _release_branch(release: str):
    """Return the remote ref of `release`: `origin/branch-X` before FROM_RELEASE, `origin/release/X` from it on."""
    if float(release) < float(FROM_RELEASE):
        return f"origin/branch-{release}"
    return f"origin/release/{release}"


def get_commits(releases: set):
    """Collect the abbreviated commit hashes that belong to each release.

    For every release the commits are those reachable from its branch but not
    from the branch of the next (older) entry in `releases`. The last entry is
    compared against the release two months earlier, as computed by
    get_prev_release_version().

    param releases: release versions, e.g. ['YY.MM2', 'YY.MM1']. The function
                    pairs each entry with the one that follows it, so callers
                    are expected to pass them in descending version order.
    return: dict of commit hashes, e.g. {YY.MM2: [sha1, sha2, ...], YY.MM1: [shaX, shaY, ...]}
    raises: subprocess.CalledProcessError if `git log` exits with a non-zero status
    """
    rel_list = list(releases)
    ver_commits = {}
    for i, to_rel in enumerate(rel_list):
        # releases[YY.MM2, YY.MM1] --> git log "YY.MM1..YY.MM2" for YY.MM2, "YY.MM0..YY.MM1" for YY.MM1
        if i + 1 < len(rel_list):
            from_rel = rel_list[i + 1]
        else:
            from_rel = get_prev_release_version(to_rel)

        # Both ends of the range go through the same naming rule, so a release
        # older than FROM_RELEASE resolves to its `branch-` ref on either side.
        rev_range = f"{_release_branch(from_rel)}..{_release_branch(to_rel)}"

        # Get all the commit hashes, excluding commits whose message contains the
        # literal text '[bot]' (-F turns the --grep pattern into a fixed string).
        git_log_args = [
            "git", "--no-pager", "log",
            rev_range, "--pretty=format:%h",
            "--grep=[bot]", "-F", "--invert-grep"
        ]

        # Use check=True to raise exception if git fails, making errors explicit
        result = subprocess.run(git_log_args, capture_output=True, text=True, check=True)

        ver_commits[to_rel] = result.stdout.splitlines()
    return ver_commits

def get_pr_via_commits(ver_commits: dict, token: str):
    """Look up the pull request associated with each commit hash.

    Only the first associated pull request of a commit is used, and only if it
    has been merged. Every returned PR dict gets an extra 'ver' key holding the
    release version the commit was listed under.

    param ver_commits: e.g. {v1: [sha1, sha2, ...], v2: [shaX, shaY, ...]}
    param token: the token for the API
    return: list of PRs associated with the commit hashes, e.g. [{PR1 info}, {PR2 info}, ...]
    raises: Exception if the API answers with a non-200 status code
    """
    pr_list = []
    for version, commits in ver_commits.items():
        for sha in commits:
            res = post(query=query_pr_by_commit, token=token, variable={'sha': sha})
            # Same policy as fetch(): an HTTP failure (bad token, rate limit, ...) aborts
            # the run instead of being reported as "commit has no PR" for every commit.
            if res.status_code != 200:
                raise Exception("Query failed to run by returning code of {}. {}".format(
                    res.status_code, query_pr_by_commit))
            try:
                pr_item = res.json()['data']['repository']['commit']['associatedPullRequests']['edges'][0]['node']
                pr_item['ver'] = version
                # Handle the case of multiple commits being associated with the same PR
                if pr_item not in pr_list and pr_item['mergedAt'] is not None:
                    pr_list.append(pr_item)
            except (KeyError, IndexError, TypeError, ValueError) as e:
                # Best effort: a commit without a resolvable PR is reported and skipped.
                print(f"Exception: {e}, commit sha '{sha}' does not have the associated Pull Request")
                continue
    return pr_list

def process_changelog(resource_type: str, changelog: dict, releases: set, projects: set, token: str):
    """Fetch PRs or issues and file them into `changelog`, grouped by release project.

    The project of an item is "Release <version>", where the version comes from,
    in increasing priority: the PR's base branch name (PRs only), the "Roadmap"
    value of the item's first project item, and the item's 'ver' key when present.
    Items whose project is not in `projects` are dropped.

    PRs are always filed under the PRs section; issues are filed under the
    section chosen by rules(). A release entry is only created in `changelog`
    once an item is actually going to be added to it, so skipped items never
    leave an empty release section behind.

    param resource_type: PULL_REQUESTS or ISSUES; anything else exits with status 1
    param changelog: dict updated in place, {project: {subtitle: [entry, ...]}}
    param releases: release versions, passed through to process_pr() / process_issue()
    param projects: project names ("Release <version>") to keep
    param token: the token for the API
    """
    if resource_type == PULL_REQUESTS:
        items = process_pr(releases=releases, token=token)
        time_field = 'mergedAt'
    elif resource_type == ISSUES:
        items = process_issue(releases=releases, token=token)
        time_field = 'closedAt'
    else:
        print(f"[process_changelog] Invalid type: {resource_type}")
        sys.exit(1)

    is_pr = resource_type == PULL_REQUESTS
    for item in items:
        title = item['title']
        if is_pr and '[bot]' in title:
            continue  # skip auto-gen PR

        nodes = item["projectItems"]["nodes"]
        roadmap = nodes[0]['roadmap'] if nodes else None
        if roadmap:
            ver = roadmap['name']
        elif is_pr:
            # Obtain the version from the PR's target branch, e.g. branch-x.y --> x.y
            ver = item['baseRefName'].replace('branch-', '')
        else:
            continue  # an issue without a roadmap value cannot be mapped to a release

        # Overwrite project version after the {FROM_RELEASE} if provided
        if item.get('ver') is not None:
            ver = item['ver']
        project = f"{RELEASE} {ver}"

        if not release_project(project, projects):
            continue

        if is_pr:
            category = PRS
            if '[databricks]' in title:  # strip ambiguous CI annotation
                title = title.replace('[databricks]', '').strip()
        else:
            labels = {label['name'] for label in item['labels']['nodes']}
            category = rules(labels)
            if category == INVALID:
                continue

        if project not in changelog:
            changelog[project] = {
                FEATURES: [],
                PERFORMANCE: [],
                BUGS_FIXED: [],
                PRS: [],
            }

        changelog[project][category].append({
            "number": item['number'],
            "title": title,
            "url": item['url'],
            "time": item[time_field],
        })


def process_pr(releases: set, token: str):
    """Get the PRs based on the release versions.

    The first entry of `releases` decides how PRs are collected:
      * newer than FROM_RELEASE: PRs are resolved from the git commits of every
        requested release;
      * equal to FROM_RELEASE: PRs of FROM_RELEASE are resolved from git commits,
        and PRs of the release before it are queried by base branch;
      * older than FROM_RELEASE: PRs of every requested release are queried by
        base branch `branch-<release>`.

    Note: only the last 2 releases are supported/included in the changelog.

    param releases: release versions; the first one is treated as the current one
    param token: the token for the API
    return: list of PR dicts
    """
    newest = list(releases)[0]
    newest_num = float(newest)
    cutover = float(FROM_RELEASE)

    # Both releases are after {FROM_RELEASE}
    if newest_num > cutover:
        return get_pr_via_commits(get_commits(releases), token)

    # One release is the {FROM_RELEASE}, the other is before the {FROM_RELEASE}
    if newest_num == cutover:
        collected = get_pr_via_commits(get_commits({FROM_RELEASE}), token)
        older = get_prev_release_version(current_ver=FROM_RELEASE)
        collected.extend(fetch(resource_type=PULL_REQUESTS, token=token,
                               variables={'baseRefName': f"branch-{older}"}))
        return collected

    # Both releases are before the {FROM_RELEASE}
    collected = []
    for version in releases:
        collected.extend(fetch(resource_type=PULL_REQUESTS, token=token,
                               variables={'baseRefName': f"branch-{version}"}))
    return collected


def process_issue(releases: set, token: str):
    """Fetch the closed issues that may belong to the requested releases.

    The look-back window handed to the issue query is three months per
    requested release, counted back from the current UTC time.

    param releases: release versions; only their count is used here
    param token: the token for the API
    return: list of issue dicts as returned by fetch()
    """
    window = relativedelta(months=3 * len(releases))
    # datetime.utcnow() is deprecated and yields a naive value; a timezone-aware
    # timestamp serializes with an explicit UTC offset.
    since = (datetime.now(timezone.utc) - window).isoformat()
    return fetch(resource_type=ISSUES, token=token, variables={'since': since})


def fetch(resource_type: str, token: str, variables: dict = None):
    """Run the paginated PR or issue query and return all nodes of all pages.

    param resource_type: PULL_REQUESTS (requires non-empty `variables`) or ISSUES;
                         any other combination returns an empty list without querying
    param token: the token for the API
    param variables: GraphQL variables for the query; the dict is not modified
    return: list of node dicts collected over every page
    raises: Exception if the API answers with a non-200 status code or the
            response carries no data for the requested resource
    """
    items = []
    if resource_type == PULL_REQUESTS and variables:
        q = query_pr
    elif resource_type == ISSUES:
        q = query_issue
    else:
        return items

    # Private copy: the pagination cursor is written into it on every page, which
    # must neither fail when `variables` is None nor leak into the caller's dict.
    params = dict(variables) if variables else {}
    has_next = True
    while has_next:
        res = post(query=q, token=token, variable=params)
        if res.status_code != 200:
            raise Exception("Query failed to run by returning code of {}. {}".format(res.status_code, q))
        d = res.json()
        # GraphQL can answer 200 with an `errors` list and no usable data.
        connection = ((d.get('data') or {}).get('repository') or {}).get(resource_type)
        if connection is None:
            raise Exception("Query returned no data: {}. {}".format(d.get('errors'), q))
        has_next = connection["pageInfo"]["hasNextPage"]
        params['after'] = connection["pageInfo"]["endCursor"]
        items.extend(connection['nodes'])
    return items


def post(query: str, token: str, variable: dict, timeout: float = 60):
    """Send a GraphQL request to the GitHub API and return the raw response.

    param query: the GraphQL query text
    param token: the token for the API, sent in the Authorization header
    param variable: the GraphQL variables of the query
    param timeout: seconds to wait for the server before giving up; without a
                   timeout a stalled connection would block the script forever
    return: the requests.Response; status and payload are left to the caller
    raises: requests.exceptions.RequestException on connection errors or timeout
    """
    return requests.post('https://api.github.com/graphql',
                         json={'query': query, 'variables': variable},
                         headers={"Authorization": f"token {token}"},
                         timeout=timeout)


def release_project(project_name: str, projects: set):
    """Tell whether `project_name` is one of the requested release projects.

    param project_name: project name, e.g. 'Release 25.12'
    param projects: the project names to keep
    return: True if `project_name` is in `projects`, otherwise False
    """
    return project_name in projects


def rules(labels: set):
    """Map the labels of an issue onto a changelog subtitle.

    Excluding labels (wontfix / invalid / duplicate) win over everything else.
    Otherwise the first matching group decides, in the order
    Bugs Fixed > Performance > Features.

    param labels: the label names of the issue
    return: BUGS_FIXED, PERFORMANCE or FEATURES; INVALID when the issue is
            excluded or carries none of the known labels
    """
    excluding = (LABEL_WONTFIX, LABEL_INVALID, LABEL_DUPLICATE)
    if any(tag in labels for tag in excluding):
        return INVALID

    # Ordered by priority: the first group with a matching label wins.
    precedence = (
        (BUGS_FIXED, (LABEL_BUG,)),
        (PERFORMANCE, (LABEL_PERFORMANCE, LABEL_SHUFFLE)),
        (FEATURES, (LABEL_FEATURE, LABEL_SQL)),
    )
    for subtitle, tags in precedence:
        if any(tag in labels for tag in tags):
            return subtitle
    return INVALID


def form_changelog(path: str, changelog: dict):
    """Render `changelog` as Markdown and write it to `path`.

    Releases are written in descending order of their project name. Each release
    lists its Features, Performance, Bugs Fixed and PRs subsections in that
    order; empty subsections are omitted by form_subsection().

    param path: destination file; overwritten if it exists
    param changelog: {project: {subtitle: [entry, ...]}}
    """
    parts = []
    # Project names are unique dict keys, so sorting the keys alone gives the
    # same order as sorting the (name, sections) pairs.
    for project_name in sorted(changelog, reverse=True):
        issues = changelog[project_name]
        parts.append(f"\n\n## {project_name}")
        for subtitle in (FEATURES, PERFORMANCE, BUGS_FIXED, PRS):
            parts.append(form_subsection(issues, subtitle))
    subsections = "".join(parts)
    markdown = f"""# Change log
Generated on {date.today()}{subsections}
\n## Older Releases
Changelog of older releases can be found at [docs/archives](/docs/archives)
"""
    # Titles come from GitHub and may contain non-ASCII characters; do not depend
    # on the platform's default encoding.
    with open(path, "w", encoding="utf-8") as file:
        file.write(markdown)


def form_subsection(issues: dict, subtitle: str):
    """Render one subsection of a release as a two-column Markdown table.

    Entries are listed newest first, ordered by their 'time' value.

    param issues: {subtitle: [entry, ...]}; every entry has the keys
                  'number', 'title', 'url' and 'time'
    param subtitle: the subsection to render, e.g. 'Bugs Fixed'
    return: the Markdown text, or '' when the subsection has no entries
    """
    entries = issues[subtitle]
    if not entries:
        return ''
    rows = [f"\n\n### {subtitle}", "\n|||\n|:---|:---|"]
    for issue in sorted(entries, key=lambda x: x['time'], reverse=True):
        # A raw '|' in a title would be read as a cell separator and break the row.
        title = issue['title'].replace('|', '\\|')
        rows.append(f"\n|[#{issue['number']}]({issue['url']})|{title}|")
    return "".join(rows)



def main(rels: str, path: str, token: str):
    """Generate the changelog for the given releases and write it to `path`.

    Any exception raised while collecting or writing the data is reported on
    stderr and turns into exit status 1.

    param rels: comma-separated release versions, e.g. '25.12,25.10'
    param path: path of the generated changelog file
    param token: the token for the API
    """
    print('Generating changelog ...')

    try:
        changelog = {}  # changelog dict
        # Blank entries (e.g. from a trailing comma) are dropped; they would
        # otherwise surface later as an unrelated number-parsing error.
        releases = {x.strip() for x in rels.split(',') if x.strip()}
        if not releases:
            raise ValueError("no release versions were provided")
        # Sort releases in descending order for the follow-up operations
        releases = sorted(releases, reverse=True)
        projects = {f"{RELEASE} {rel}" for rel in releases}

        print('Processing pull requests ...')
        process_changelog(resource_type=PULL_REQUESTS, changelog=changelog,
                          releases=releases, projects=projects, token=token)
        print('Processing issues ...')
        process_changelog(resource_type=ISSUES, changelog=changelog,
                          releases=releases, projects=projects, token=token)
        # form doc
        form_changelog(path=path, changelog=changelog)
    except Exception as e:  # pylint: disable=broad-except
        # Include the exception type: the bare message of e.g. a KeyError is just the key.
        print(f"{type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(1)

    print('Done.')


if __name__ == '__main__':
    # Command-line entry point: parse the options, resolve the token, run main().
    parser = ArgumentParser(description="Changelog Generator")
    parser.add_argument("--releases", help="list of release versions, separated by comma",
                        default="0.1,0.2,0.3")
    parser.add_argument("--token", help="github token, will use GITHUB_TOKEN if empty", default='')
    parser.add_argument("--path", help="path for generated changelog file", default='./CHANGELOG.md')
    args = parser.parse_args()

    # --token takes precedence over the GITHUB_TOKEN environment variable.
    github_token = args.token if args.token else os.environ.get('GITHUB_TOKEN')
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not github_token:
        parser.error('env GITHUB_TOKEN should not be empty')

    main(rels=args.releases, path=args.path, token=github_token)
