diff --git a/src/repairchain/strategies/generation/reversion.py b/src/repairchain/strategies/generation/reversion.py
index d6f9198e..cb6d64cd 100644
--- a/src/repairchain/strategies/generation/reversion.py
+++ b/src/repairchain/strategies/generation/reversion.py
@@ -17,7 +17,7 @@
 from repairchain.models.diff import Diff
 from repairchain.models.patch_outcome import PatchOutcome
 from repairchain.strategies.generation.base import PatchGenerationStrategy
-from repairchain.util import dd_minimize
+from repairchain.util import dd_minimize, revert_diff
 
 if t.TYPE_CHECKING:
     from sourcelocation.diff import FileHunk
@@ -80,12 +80,7 @@ def _cleanup(self, restore_to: str) -> None:
 
     def _compute_reverse_diff(self) -> Diff:
         """Computes a diff that reverses the changes introduced by the triggering commit."""
-        unidiff = self.project.repository.git.diff(
-            self.triggering_commit,
-            self.triggering_commit_parent,
-            unified=True,
-        )
-        return Diff.from_unidiff(unidiff).strip(1)
+        return revert_diff(self.project.original_implicated_diff)
 
     def _minimize_reverse_diff(self, reverse_diff: Diff) -> Diff:
         """Minimizes the reverse diff to the smallest possible diff that still undoes the triggering commit."""
diff --git a/src/repairchain/util.py b/src/repairchain/util.py
index 86624886..7b73db2d 100644
--- a/src/repairchain/util.py
+++ b/src/repairchain/util.py
@@ -10,10 +10,17 @@
 
 from dockerblade.stopwatch import Stopwatch
 from loguru import logger
+from sourcelocation.diff import (
+    DeletedLine,
+    FileHunk,
+    HunkLine,
+    InsertedLine,
+)
 
 from repairchain.models.diff import (
     Diff,
     FileDiff,
+    Hunk,
 )
 
 if t.TYPE_CHECKING:
@@ -24,6 +31,45 @@
 T = t.TypeVar("T")
 
 
+def revert_diff(diff: Diff) -> Diff:
+    """Builds the inverse of a diff.
+
+    Inserted and deleted lines trade places, and the old/new sides of every
+    hunk (start positions and filenames) are swapped, so that the result
+    describes the change from the new version back to the old version.
+    Lines that are neither inserted nor deleted are kept as they are.
+    """
+    def revert_hunk_line(line: HunkLine) -> HunkLine:
+        match line:
+            case InsertedLine(content):
+                return DeletedLine(content)
+            case DeletedLine(content):
+                return InsertedLine(content)
+        # any other kind of line (e.g., context) is left untouched
+        return line
+
+    def revert_hunk(hunk: Hunk) -> Hunk:
+        # the reverted hunk is applied to the *new* version of the file, so
+        # the original "new" position becomes its "old" position (and vice versa)
+        return Hunk(
+            old_start_at=hunk.new_start_at,
+            new_start_at=hunk.old_start_at,
+            lines=[revert_hunk_line(line) for line in hunk.lines],
+        )
+
+    def revert_file_hunk(file_hunk: FileHunk) -> FileHunk:
+        # filenames swap sides for the same reason as the start positions
+        return FileHunk(
+            old_filename=file_hunk.new_filename,
+            new_filename=file_hunk.old_filename,
+            hunk=revert_hunk(file_hunk.hunk),
+        )
+
+    return Diff.from_file_hunks([
+        revert_file_hunk(file_hunk) for file_hunk in diff.file_hunks
+    ])
+
+
 def statements_in_function(
     index: kaskara.analysis.Analysis,
     function: kaskara.functions.Function,
diff --git a/test/integration/test_util.py b/test/integration/test_util.py
index 4338729d..fff42bcb 100644
--- a/test/integration/test_util.py
+++ b/test/integration/test_util.py
@@ -1,6 +1,9 @@
 import typing as t
 
-from repairchain.util import statements_in_function
+from repairchain.util import (
+    revert_diff,
+    statements_in_function,
+)
 
 
 if t.TYPE_CHECKING:
@@ -9,6 +12,16 @@
     from repairchain.indexer import KaskaraIndexer
 
 
+def test_revert_diff(
+    example_project_factory,
+) -> None:
+    with example_project_factory("mock-cp") as project:
+        original_implicated_diff = project.original_implicated_diff
+
+        # reverting a diff twice must give back the diff that we started with
+        assert revert_diff(revert_diff(original_implicated_diff)) == original_implicated_diff
+
+
 def test_statements_in_function(
     example_project_factory,
 ) -> None: