Files
blackroad-operating-system/operator_engine/pr_actions/handlers/merge_pr.py
2025-11-17 23:38:25 -06:00

142 lines
4.5 KiB
Python

"""
Merge PR Handler
Handles merging PRs with various strategies.
"""
from typing import Dict, Any
import logging
from . import BaseHandler
from ..action_types import PRAction, PRActionType
logger = logging.getLogger(__name__)
class MergePRHandler(BaseHandler):
    """Handler for merging PRs.

    The merge strategy comes from ``action.params["merge_method"]`` when it is
    set; otherwise it is derived from ``action.action_type``.
    """

    async def execute(self, action: PRAction) -> Dict[str, Any]:
        """
        Merge a PR.

        Expected params:
            - merge_method: "merge", "squash", or "rebase" (default: from action_type)
            - commit_title: Optional custom commit title
            - commit_message: Optional custom commit message
            - skip_checks: If True, merge without waiting for checks (default: False)

        Returns:
            A dict with ``pr_number``, ``merged`` (always ``True`` when this
            method returns), ``merge_method``, and the ``sha`` / ``message``
            values read from the merge response.

        Raises:
            ValueError: If the PR cannot be found, is not reported as
                mergeable, or (unless ``skip_checks`` is set) its required
                checks are not passing.
        """
        gh = await self.get_github_client()

        # Determine merge method: an explicit param wins, otherwise fall back
        # to the strategy implied by the action type.
        merge_method = action.params.get("merge_method")
        if not merge_method:
            if action.action_type == PRActionType.SQUASH_MERGE:
                merge_method = "squash"
            elif action.action_type == PRActionType.REBASE_MERGE:
                merge_method = "rebase"
            else:
                merge_method = "merge"

        # Get the PR
        pr = await gh.get_pull_request(
            action.repo_owner, action.repo_name, action.pr_number
        )
        if not pr:
            raise ValueError(f"PR #{action.pr_number} not found")

        # Check if PR is mergeable. A missing, None, or False "mergeable"
        # value is treated as not mergeable.
        if not pr.get("mergeable", False):
            raise ValueError(
                f"PR #{action.pr_number} is not mergeable. "
                f"Merge state: {pr.get('mergeable_state')}"
            )

        # Check if checks are passing (unless skip_checks is True)
        skip_checks = action.params.get("skip_checks", False)
        if not skip_checks:
            # Reuse the PR fetched above so both gates evaluate the same
            # snapshot and no second lookup is needed.
            checks_passing = await self._check_required_checks(gh, action, pr)
            if not checks_passing:
                raise ValueError(
                    f"Required checks are not passing for PR #{action.pr_number}"
                )

        # Merge the PR
        result = await gh.merge_pull_request(
            action.repo_owner,
            action.repo_name,
            action.pr_number,
            merge_method=merge_method,
            commit_title=action.params.get("commit_title"),
            commit_message=action.params.get("commit_message"),
        )

        logger.info(
            "Merged PR #%s using %s method. Merge commit: %s",
            action.pr_number,
            merge_method,
            result.get("sha"),
        )

        return {
            "pr_number": action.pr_number,
            "merged": True,
            "merge_method": merge_method,
            "sha": result.get("sha"),
            "message": result.get("message"),
        }

    async def _check_required_checks(self, gh, action: PRAction, pr=None) -> bool:
        """Check if all required checks are passing.

        Args:
            gh: GitHub client providing ``get_pull_request``,
                ``get_check_runs`` and ``get_required_checks``.
            action: The PR action identifying the repository and PR number.
            pr: Optional already-fetched PR payload. When omitted, the PR is
                fetched through ``gh``.

        Returns:
            ``True`` when there are no required checks, or when every required
            check has at least one check run on the PR head commit whose
            conclusion is ``"success"``; ``False`` otherwise.
        """
        if pr is None:
            pr = await gh.get_pull_request(
                action.repo_owner, action.repo_name, action.pr_number
            )
        head_sha = pr["head"]["sha"]

        # Get check runs
        check_runs = await gh.get_check_runs(
            action.repo_owner, action.repo_name, head_sha
        )
        # A None result means "no check runs" rather than a crash while
        # iterating below.
        if check_runs is None:
            check_runs = []

        # Get required checks for the repo
        required_checks = await gh.get_required_checks(
            action.repo_owner, action.repo_name, pr["base"]["ref"]
        )

        # Defensive checks for required_checks
        if required_checks is None:
            required_checks = []
        elif not isinstance(required_checks, list):
            logger.warning(
                "Unexpected required_checks type: %s", type(required_checks)
            )
            required_checks = []

        # If no required checks, consider it passing
        if not required_checks:
            return True

        # Check if all required checks are passing
        for required_check in required_checks:
            # .get() tolerates check runs that lack a "name" field.
            matching_checks = [
                check for check in check_runs
                if check.get("name") == required_check
            ]
            if not matching_checks:
                logger.warning(
                    "Required check '%s' not found for PR #%s",
                    required_check,
                    action.pr_number,
                )
                return False

            # Check if any matching check passed.
            # NOTE(review): only "success" counts as passing here; GitHub
            # branch protection also accepts "neutral" and "skipped" -
            # confirm whether the stricter rule is intended.
            passed = any(
                check.get("conclusion") == "success"
                for check in matching_checks
            )
            if not passed:
                logger.warning(
                    "Required check '%s' did not pass for PR #%s",
                    required_check,
                    action.pr_number,
                )
                return False

        return True