# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Applying Gemini CLI to Fix Chromium Unsafe Buffer Usage

This script discovers files with unsafe buffer usage, categorizes each usage,
and generates spanification fixes for them.
"""

import subprocess
import contextlib
import json
import os
import sys
import argparse
import time
import re

# GEMINI.md is looked up relative to the current directory, so the script is
# expected to be run from the Chromium src/ directory.
GEMINI_MD_PATH = 'GEMINI.md'
# The prompt files live next to this script.
SCRIPT_DIR = os.path.dirname(__file__)
CATEGORIZE_PROMPT_MD, FIXING_PROMPT_MD = (
    os.path.join(SCRIPT_DIR, prompt_name)
    for prompt_name in ('prompt_categorize.md', 'prompt_fixing.md'))
# Marker text delimiting the block this script temporarily adds to GEMINI.md.
SPANIFICATION_GEMINI_MD = 'SPANIFICATION_GEMINI_MD'


def discover_unsafe_todos(folder=None):
    """Return git-tracked files that contain unsafe-buffer markers.

    Uses `git grep -l` to list every file with a line matching `UNSAFE_TODO`
    or `allow_unsafe_buffers`.

    Args:
        folder: Optional pathspec (e.g. a directory) limiting the search. When
            None or empty, the whole repository is searched.

    Returns:
        The matching paths as printed by git. An empty list when nothing
        matches or when git fails (the git error is printed).
    """
    cmd = [
        'git', 'grep', '-l', '-z', '-e', 'UNSAFE_TODO', '--or', '-e',
        'allow_unsafe_buffers'
    ]
    if folder:
        # '--' forces git to treat `folder` as a path, never as a revision
        # name or an option (e.g. a folder that starts with '-').
        cmd.extend(['--', folder])

    result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    if result.returncode != 0:
        # git grep returns 1 if no lines are selected, which is not an error.
        if result.returncode == 1 and result.stdout == "":
            return []
        print("Error discovering files:", result.stderr)
        return []
    # With -z, each path is NUL-terminated and printed verbatim (no C-style
    # quoting of unusual characters); drop the empty trailing entry.
    return [path for path in result.stdout.split('\0') if path]


@contextlib.contextmanager
def setup_gemini_context_md(file_path):
    """Context manager that temporarily imports `file_path` from GEMINI.md.

    On entry, every marker-delimited block is stripped from GEMINI.md (e.g.
    one left behind by an interrupted run), and a fresh block containing
    `@{file_path}` is appended. On exit, whether normal or via an exception
    raised in the `with` body, the block is stripped again. The rest of
    GEMINI.md is left as it is at that moment.

    If GEMINI.md is not in the current directory, prints an error and calls
    sys.exit(1).
    """
    opening = f"# {SPANIFICATION_GEMINI_MD}"
    closing = f"# /{SPANIFICATION_GEMINI_MD}\n"
    # Non-greedy match from an opening marker to the next closing marker;
    # DOTALL lets '.' cross the newlines in between.
    marked_block = re.compile(opening + ".*?" + closing, re.DOTALL)

    def rewrite_gemini_md(suffix):
        """Strip every marked block from GEMINI.md, then append `suffix`."""
        if not os.path.exists(GEMINI_MD_PATH):
            print("Error: the script is expected to be run from the src/ "
                  "directory where GEMINI.md is located.")
            sys.exit(1)
        with open(GEMINI_MD_PATH, 'r', encoding='utf-8') as md_file:
            original_text = md_file.read()
        updated_text = marked_block.sub("", original_text) + suffix
        with open(GEMINI_MD_PATH, 'w', encoding='utf-8') as md_file:
            md_file.write(updated_text)

    rewrite_gemini_md(f"{opening}\n@{file_path}\n{closing}")
    try:
        yield
    finally:
        rewrite_gemini_md("")


# Directory that holds the files exchanged with gemini during a single run.
# It is relative to the current working directory, which is expected to be
# src/.
GEMINI_OUT_DIR = 'gemini_out'

# run_gemini() returns the parsed summary.json, enriched by the script. The
# expected structure is:
# {
#   "status": "SUCCESS" | "FAILURE" (as written by gemini), or a fallback set
#             by run_gemini(): "JSON_ERROR" (summary.json could not be read
#             or parsed), "TIMEOUT" or "GEMINI_FAILURE" (summary.json was not
#             written),
#   "summary": "A summary of the changes made.",
#   "variable_type": "The categorized variable type.",
#   "access_type": "The categorized access type.",
#   ... and other fields generated by the gemini tool ...
#   -- fields added by run_gemini(): --
#   "task_prompt": "The full prompt passed to the gemini tool.",
#   "task_args": ["The file path and other arguments for the task."],
#   "exit_code": <int> exit code of the gemini tool (124 on timeout),
#   "duration": <int> wall-clock duration of the gemini run in seconds,
#   "session_summary": { ... parsed session-summary.json, if readable ... }
# }
SUMMARY_JSON_PATH = GEMINI_OUT_DIR + '/summary.json'
# Commit message gemini is expected to write for a successful fix; consumed
# by autocommit_changes() via `git commit -F`.
COMMIT_MESSAGE_PATH = GEMINI_OUT_DIR + '/commit_message.md'
# Passed to `gemini --session-summary`.
SESSION_SUMMARY_PATH = GEMINI_OUT_DIR + '/session-summary.json'


def run_gemini(prompt, task_args, interactive=False, yolo=False):
    """Run the gemini CLI with `prompt` and return its enriched summary.

    Output files left over from a previous run are deleted first, so every
    file read afterwards was produced by this invocation.

    Args:
        prompt: Prompt text, passed with `-p` (or `-i` when `interactive`).
        task_args: Values recorded verbatim under 'task_args' in the result.
        interactive: Run gemini interactively and without a timeout.
        yolo: Pass `--yolo` instead of the restricted tool allow-list.

    Returns:
        The dict parsed from summary.json. If that file is missing, the dict
        is a fallback {'status': 'TIMEOUT'} or {'status': 'GEMINI_FAILURE'};
        if it is unreadable or not a JSON object, it is
        {'status': 'JSON_ERROR'}. In every case the dict is enriched with
        'task_prompt', 'task_args', 'exit_code' and 'duration', plus
        'session_summary' when session-summary.json could be parsed.
    """
    ALLOWED_TOOLS = [
        # read tools were always allowed
        "write_file",
        "replace",
        "ShellTool(rg)",
        "ShellTool(fdfind)",
        "ShellTool(cs)",
        "ShellTool(vpython3 tools/utr)",
    ]

    os.makedirs(GEMINI_OUT_DIR, exist_ok=True)
    # Clean up previous run files
    for f in [SUMMARY_JSON_PATH, COMMIT_MESSAGE_PATH, SESSION_SUMMARY_PATH]:
        if os.path.exists(f):
            os.remove(f)

    cmd = [
        'gemini',
        '--session-summary',
        SESSION_SUMMARY_PATH,
    ]
    if yolo:
        cmd.append('--yolo')
    else:
        cmd.extend([
            '--approval-mode', 'auto_edit', '--allowed-tools', *ALLOWED_TOOLS
        ])

    cmd.extend(['-i' if interactive else '-p', prompt])

    start_time = time.time()
    try:
        # Using a timeout of 3000 seconds as in fix.sh; interactive sessions
        # are not timed out.
        timeout = None if interactive else 3000
        result = subprocess.run(cmd, timeout=timeout, check=False)
        gemini_exit_code = result.returncode
    except subprocess.TimeoutExpired:
        gemini_exit_code = 124  # timeout exit code in linux
    duration = int(time.time() - start_time)

    if os.path.exists(SUMMARY_JSON_PATH):
        try:
            with open(SUMMARY_JSON_PATH, 'r', encoding='utf-8') as f:
                final_summary = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading {SUMMARY_JSON_PATH}: {e}")
            final_summary = {'status': 'JSON_ERROR'}
        else:
            # A list/string/number can't be enriched with the keys below.
            if not isinstance(final_summary, dict):
                print(f"Error reading {SUMMARY_JSON_PATH}: "
                      "expected a JSON object.")
                final_summary = {'status': 'JSON_ERROR'}
    else:
        print(
            f"Error: {SUMMARY_JSON_PATH} not found by gemini. Creating a "
            "fallback."
        )
        # TODO: Add check for MAX_TURNS_REACHED by inspecting gemini.log
        if gemini_exit_code == 124:
            final_summary = {'status': 'TIMEOUT'}
        else:
            final_summary = {'status': 'GEMINI_FAILURE'}

    # Enrich the summary with the details of this invocation.
    final_summary['task_prompt'] = prompt
    final_summary['task_args'] = task_args
    final_summary['exit_code'] = gemini_exit_code
    final_summary['duration'] = duration

    if os.path.exists(SESSION_SUMMARY_PATH):
        try:
            with open(SESSION_SUMMARY_PATH, 'r', encoding='utf-8') as f:
                final_summary['session_summary'] = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error reading {SESSION_SUMMARY_PATH}: {e}")

    return final_summary


def categorize_file(file_path, yolo=False):
    """Ask gemini to categorize the unsafe buffer usage in `file_path`.

    The categorization prompt file is imported into GEMINI.md only while
    gemini runs. Returns the summary dict produced by run_gemini().
    """
    prompt = f"detect the unsafe access and variable category for {file_path}"
    with setup_gemini_context_md(CATEGORIZE_PROMPT_MD):
        return run_gemini(prompt, [file_path], yolo=yolo)


def generate_fix(file_path,
                 variable_type=None,
                 access_type=None,
                 interactive=False,
                 yolo=False):
    """Generate a spanification fix for `file_path` based on its categories.

    Args:
        file_path: File to fix.
        variable_type: Variable category (a VARIABLE_PROMPTS key below), or
            None/empty when no categorization is available.
        access_type: Access-pattern category (an ACCESS_PROMPTS key below),
            or None/empty when no categorization is available.
        interactive: Run gemini interactively.
        yolo: Run gemini in YOLO mode.

    Returns:
        The summary dict from run_gemini(). Gemini is not run when a category
        was provided but is unknown (including the case where both are
        unknown). In that case a {'status': 'NOT_SUPPORTED', ...} dict is
        returned instead.
    """

    VARIABLE_PROMPTS = {
        'Already-Safe': 'No changes to the variable are needed.',
        'Local-Variable': 'Arrayify the variable using `std::to_array`.',
        'Local-Method-Argument': (
            'Change the method signature to take a `std::span`.'
        ),
        'Class-Method-with-Safe-Variant': (
            'Replace the unsafe methods (that return a buffer) '
            'with a safe variant.'
        ),
        'Method-Argument': (
            'Change the method signature to take a `std::span` '
            'and update all call sites.'
        ),
        'Global-Variable': (
            'Arrayify the variable using `std::to_array` '
            'and update all usages.'
        ),
        'Class-Method-Safe-Variant-TODO': (
            'Migrate the internal members to safe containers '
            'and create a new safe method variant.'
        ),
    }

    ACCESS_PROMPTS = {
        'operator[]': (
            'The access should be safe now, '
            'just remove the `UNSAFE_TODO`.'
        ),
        'Pointer-Arithmetic': (
            'Use `base::span::first(N)`, `base::span::subspan(offset, count)` '
            '... instead.'
        ),
        'Safe-Container-Construction': (
            'Convert `base::span(pointer, size)` to a safe constructor '
            'like `base::span(container)`. If the size changed, you could use '
            '`base::span(other_span).subspan(...)` or `first(...)` to create '
            'safe views into existing spans.'
        ),
        'std::memcmp': 'Replace the comparison with `operator==`.',
        'std::strcmp': 'Replace the comparison with `operator==`.',
        'std::memcpy': 'Replace the copy with `base::span::copy_from()`.',
        'std::strncpy': 'Replace the copy with `base::span::copy_from()`.',
        'std::strcpy': 'Replace the copy with `base::span::copy_from()`.',
        'std::memset': (
            'Replace memset with `std::ranges::fill()` or `<instance> = {}`.'
        ),
        'std::strstr': 'Replace the search with `std::string_view::find()`.',
        'std::wcslen': 'Just get size() from the safe container.',
        'std::strlen': 'Just get size() from the safe container.',
    }

    task_args = [file_path, variable_type, access_type]

    def not_supported(field, value):
        """Report an unknown category and build the NOT_SUPPORTED result."""
        print(f"Warning: Unknown {field} ('{value}').")
        return {
            'status': 'NOT_SUPPORTED',
            'summary': f"Unknown {field}: {value}",
            'task_args': task_args,
            'duration': 0,
        }

    if not variable_type and not access_type:
        # No categorization available (e.g. --fix-only): use a generic prompt
        # and let gemini work out the categories itself. Unknown categories
        # must NOT land here, so this keys on the inputs, not the lookups.
        generated_prompt = f"Fix the unsafe buffer usage in {file_path}."
    else:
        variable_prompt = VARIABLE_PROMPTS.get(variable_type, '')
        access_prompt = ACCESS_PROMPTS.get(access_type, '')
        if not variable_prompt:
            return not_supported('variable_type', variable_type)
        if not access_prompt:
            return not_supported('access_type', access_type)
        generated_prompt = (
            f"The variable in {file_path} is of type {variable_type}. "
            f"{variable_prompt} The unsafe access pattern is {access_type}. "
            f"{access_prompt} ")

    with setup_gemini_context_md(FIXING_PROMPT_MD):
        return run_gemini(generated_prompt, task_args, interactive, yolo=yolo)


def autocommit_changes(fix_result, file_path):
    """Commit the working-tree changes if the fix succeeded, otherwise reset.

    Changes are committed only when `fix_result['status']` is 'SUCCESS', the
    commit message file exists, and `git commit` succeeds. In every other
    case, including a failing `git commit` (for example, nothing to commit
    or a rejecting hook), tracked files are hard-reset to HEAD so the next
    file starts from a clean tree instead of the batch aborting.

    NOTE(review): `git commit -a` only stages already-tracked files, and
    `git reset --hard HEAD` discards ALL uncommitted edits to tracked files,
    including ones made before the script ran. Untracked files (e.g. new
    files created by gemini) are neither committed nor removed.
    """
    def reset_to_head():
        subprocess.run(['git', 'reset', '--hard', 'HEAD'], check=True)

    if fix_result.get('status') != 'SUCCESS':
        print(f"Fix generation failed for {file_path}. Resetting to HEAD.")
        reset_to_head()
        return

    print(f"Successfully fixed {file_path}. Committing changes.")
    if not os.path.exists(COMMIT_MESSAGE_PATH):
        print(f"Warning: {COMMIT_MESSAGE_PATH} not found. Cannot commit. "
              "Resetting to HEAD.")
        reset_to_head()
        return

    # check=False: a failed commit must not abort processing of later files.
    commit = subprocess.run(['git', 'commit', '-a', '-F', COMMIT_MESSAGE_PATH],
                            check=False)
    if commit.returncode != 0:
        print(f"Warning: git commit failed for {file_path} (exit code "
              f"{commit.returncode}). Resetting to HEAD.")
        reset_to_head()
        return
    print(f"Committed fix for {file_path}.")


def main():
    """Command-line entry point.

    Builds the list of files to process. This is either the given file,
    every file under the given folder that contains unsafe-buffer markers,
    or every such file in the repository when no path is given. For each
    file it then runs categorization and/or fix generation, as selected by
    the flags.
    """
    parser = argparse.ArgumentParser(
        description='Discover, categorize and generate spanification fixes.')
    parser.add_argument('path',
                        nargs='?',
                        default=None,
                        help='The file or folder to process.')
    # The two "only" flags contradict each other: combining them used to skip
    # both steps silently, so let argparse reject the combination.
    only_group = parser.add_mutually_exclusive_group()
    only_group.add_argument('--categorize-only',
                            action='store_true',
                            help='Only run the categorization step.')
    only_group.add_argument('--fix-only',
                            action='store_true',
                            help='Only run the fix generation step.')
    parser.add_argument('-y',
                        '--yolo',
                        action='store_true',
                        help='Enable YOLO mode for Gemini CLI.')
    parser.add_argument('-i',
                        '--interactive',
                        action='store_true',
                        help='Run in interactive mode for fixing.')
    parser.add_argument('--autocommit',
                        action='store_true',
                        help='Automatically commit successful fixes.')
    args = parser.parse_args()

    if args.autocommit and args.interactive:
        # autocommit_changes() is skipped for interactive runs below.
        print("Warning: --autocommit is ignored in interactive mode.")

    if args.path and not os.path.isdir(args.path):
        files_to_process = [args.path]
    else:
        # A folder restricts the search; no path searches the whole repo.
        files_to_process = discover_unsafe_todos(args.path)

    if not files_to_process:
        print("No files to process.")
        return

    for file_path in files_to_process:
        print(f"Processing {file_path}...")
        categorization_result = None
        if not args.fix_only:
            categorization_result = categorize_file(file_path, yolo=args.yolo)
            print("Categorization result:", json.dumps(categorization_result))

        if args.categorize_only:
            continue

        variable_type = None
        access_type = None
        if categorization_result:
            if categorization_result.get('status') != 'SUCCESS':
                print(f"Skipping fix generation for {file_path} due to "
                      "categorization failure.")
                continue
            # The categorization step is expected to return 'variable_type'
            # and 'access_type' in summary.json
            variable_type = categorization_result.get('variable_type')
            access_type = categorization_result.get('access_type')

        fix_result = generate_fix(file_path,
                                  variable_type,
                                  access_type,
                                  interactive=args.interactive,
                                  yolo=args.yolo)
        print("Fix generation result:", json.dumps(fix_result))

        if args.autocommit and not args.interactive:
            autocommit_changes(fix_result, file_path)


if __name__ == '__main__':
    main()
