diff --git a/api/src/main/java/net/neoforged/jst/api/ImportHelper.java b/api/src/main/java/net/neoforged/jst/api/ImportHelper.java new file mode 100644 index 0000000..0a92b0d --- /dev/null +++ b/api/src/main/java/net/neoforged/jst/api/ImportHelper.java @@ -0,0 +1,150 @@ +package net.neoforged.jst.api; + +import com.intellij.psi.PsiClass; +import com.intellij.psi.PsiField; +import com.intellij.psi.PsiFile; +import com.intellij.psi.PsiImportStatementBase; +import com.intellij.psi.PsiImportStaticStatement; +import com.intellij.psi.PsiJavaFile; +import com.intellij.psi.PsiMethod; +import com.intellij.psi.PsiModifier; +import com.intellij.psi.PsiPackage; +import org.jetbrains.annotations.Nullable; +import org.jetbrains.annotations.VisibleForTesting; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Helper class used to import classes while processing a source file. + * @see ImportHelper#get(PsiFile) + */ +public class ImportHelper implements PostProcessReplacer { + private final PsiJavaFile psiFile; + private final Map importedNames = new HashMap<>(); + + private final Set successfulImports = new HashSet<>(); + + public ImportHelper(PsiJavaFile psiFile) { + this.psiFile = psiFile; + + if (psiFile.getPackageStatement() != null) { + var resolved = psiFile.getPackageStatement().getPackageReference().resolve(); + // We cannot import a class with the name of a class in the package of the file + if (resolved instanceof PsiPackage pkg) { + for (PsiClass cls : pkg.getClasses()) { + importedNames.put(cls.getName(), cls.getQualifiedName()); + } + } + } + + if (psiFile.getImportList() != null) { + for (PsiImportStatementBase stmt : psiFile.getImportList().getImportStatements()) { + var res = stmt.resolve(); + if (res instanceof PsiPackage pkg) { + // Wildcard package imports will reserve all names of top-level classes in the package + for (PsiClass cls : pkg.getClasses()) { + importedNames.put(cls.getName(), cls.getQualifiedName()); + } + } else if (res instanceof PsiClass cls) { + importedNames.put(cls.getName(), cls.getQualifiedName()); + } + } + + for (PsiImportStaticStatement stmt : psiFile.getImportList().getImportStaticStatements()) { + var res = stmt.resolve(); + if (res instanceof PsiMethod method) { + importedNames.put(method.getName(), method.getName()); + } else if (res instanceof PsiField fld) { + importedNames.put(fld.getName(), fld.getName()); + } else if (res instanceof PsiClass cls && stmt.isOnDemand()) { + // On-demand imports are static wildcard imports which will reserve the names of + // - all static methods available through the imported class + for (PsiMethod met : cls.getAllMethods()) { + if (met.getModifierList().hasModifierProperty(PsiModifier.STATIC)) { + importedNames.put(met.getName(), met.getName()); + } + } + + // - all fields available through the imported class + for (PsiField fld : cls.getAllFields()) { + if (fld.getModifierList() != null && fld.getModifierList().hasModifierProperty(PsiModifier.STATIC)) { + importedNames.put(fld.getName(), fld.getName()); + } + } + + // - all inner classes available through the imported class directly + for (PsiClass c : cls.getAllInnerClasses()) { + importedNames.put(c.getName(), c.getQualifiedName()); + } + + // Note: to avoid possible issues, none of the above check for visibility. We prefer to be more conservative to make sure the output sources compile + } + } + } + } + + @VisibleForTesting + public boolean canImport(String name) { + return !importedNames.containsKey(name); + } + + /** + * Attempts to import the given fully qualified class name, returning a reference to it which is either + * its short name (if an import is successful) or the qualified name if not. + */ + public String importClass(String cls) { + var clsByDot = cls.split("\\."); + // We do not try to import classes in the default package or classes already imported + if (clsByDot.length == 1 || successfulImports.contains(cls)) { + return clsByDot[clsByDot.length - 1]; + } + // We also do not want to import classes under java.lang.* + else if (clsByDot.length == 3 && clsByDot[0].equals("java") && clsByDot[1].equals("lang")) { + return clsByDot[2]; + } + + var name = clsByDot[clsByDot.length - 1]; + + if (Objects.equals(importedNames.get(name), cls)) { + return name; + } + + if (canImport(name)) { + successfulImports.add(cls); + return name; + } + + return cls; + } + + @Override + public void process(Replacements replacements) { + if (successfulImports.isEmpty()) return; + + var insertion = successfulImports.stream() + .sorted() + .map(s -> "import " + s + ";") + .collect(Collectors.joining("\n")); + + if (psiFile.getImportList() != null && psiFile.getImportList().getLastChild() != null) { + var lastImport = psiFile.getImportList().getLastChild(); + replacements.insertAfter(lastImport, "\n\n" + insertion); + } else { + replacements.insertBefore(psiFile.getClasses()[0], insertion + "\n\n"); + } + } + + @Nullable + public static ImportHelper get(PsiFile file) { + return file instanceof PsiJavaFile j ? get(j) : null; + } + + public static ImportHelper get(PsiJavaFile file) { + return PostProcessReplacer.getOrCreateReplacer(file, ImportHelper.class, k -> new ImportHelper(file)); + } +} diff --git a/api/src/main/java/net/neoforged/jst/api/PostProcessReplacer.java b/api/src/main/java/net/neoforged/jst/api/PostProcessReplacer.java new file mode 100644 index 0000000..26b5d73 --- /dev/null +++ b/api/src/main/java/net/neoforged/jst/api/PostProcessReplacer.java @@ -0,0 +1,38 @@ +package net.neoforged.jst.api; + +import com.intellij.openapi.util.Key; +import com.intellij.psi.PsiFile; +import org.jetbrains.annotations.UnmodifiableView; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.function.Function; + +/** + * A replacer linked to a {@link PsiFile} will run and collect replacements after all {@link SourceTransformer transformers} have processed the file. + */ +public interface PostProcessReplacer { + Key, PostProcessReplacer>> REPLACERS = Key.create("jst.post_process_replacers"); + + /** + * Process replacements in the file after {@link SourceTransformer transformers} have processed it. + */ + void process(Replacements replacements); + + @UnmodifiableView + static Map, PostProcessReplacer> getReplacers(PsiFile file) { + var rep = file.getUserData(REPLACERS); + return rep == null ? Map.of() : Collections.unmodifiableMap(rep); + } + + static T getOrCreateReplacer(PsiFile file, Class type, Function creator) { + var rep = file.getUserData(REPLACERS); + if (rep == null) { + rep = new IdentityHashMap<>(); + file.putUserData(REPLACERS, rep); + } + //noinspection unchecked + return (T)rep.computeIfAbsent(type, k -> creator.apply(file)); + } +} diff --git a/cli/src/main/java/net/neoforged/jst/cli/SourceFileProcessor.java b/cli/src/main/java/net/neoforged/jst/cli/SourceFileProcessor.java index 5e7dd06..cb8c9f1 100644 --- a/cli/src/main/java/net/neoforged/jst/cli/SourceFileProcessor.java +++ b/cli/src/main/java/net/neoforged/jst/cli/SourceFileProcessor.java @@ -6,6 +6,7 @@ import net.neoforged.jst.api.FileSink; import net.neoforged.jst.api.FileSource; import net.neoforged.jst.api.Logger; +import net.neoforged.jst.api.PostProcessReplacer; import net.neoforged.jst.api.Replacement; import net.neoforged.jst.api.Replacements; import net.neoforged.jst.api.SourceTransformer; @@ -153,6 +154,10 @@ private byte[] transformSource(VirtualFile contentRoot, FileEntry entry, List T parseSingleElement(@Language("JAVA") String javaCode, Class type) { + var file = ijEnv.parseFileFromMemory("Test.java", javaCode); + + var elements = PsiTreeUtil.collectElementsOfType(file, type); + assertEquals(1, elements.size()); + return elements.iterator().next(); + } +} diff --git a/interfaceinjection/src/main/java/net/neoforged/jst/interfaceinjection/InjectInterfacesVisitor.java b/interfaceinjection/src/main/java/net/neoforged/jst/interfaceinjection/InjectInterfacesVisitor.java index e24e572..f129240 100644 --- a/interfaceinjection/src/main/java/net/neoforged/jst/interfaceinjection/InjectInterfacesVisitor.java +++ b/interfaceinjection/src/main/java/net/neoforged/jst/interfaceinjection/InjectInterfacesVisitor.java @@ -8,6 +8,7 @@ import com.intellij.psi.PsiWhiteSpace; import com.intellij.psi.util.ClassUtil; import com.intellij.util.containers.MultiMap; +import net.neoforged.jst.api.ImportHelper; import net.neoforged.jst.api.Replacements; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -60,6 +61,8 @@ private void inject(PsiClass psiClass, Collection targets) { return; } + var imports = ImportHelper.get(psiClass.getContainingFile()); + var implementsList = psiClass.isInterface() ? psiClass.getExtendsList() : psiClass.getImplementsList(); var implementedInterfaces = Arrays.stream(implementsList.getReferencedTypes()) .map(PsiClassType::resolve) @@ -71,8 +74,8 @@ private void inject(PsiClass psiClass, Collection targets) { .distinct() .map(stubs::createStub) .filter(iface -> !implementedInterfaces.contains(iface.interfaceDeclaration())) - .map(StubStore.InterfaceInformation::toString) - .map(this::decorate) + .map(iface -> possiblyImport(imports, iface)) + .map(iface -> decorate(imports, iface)) .sorted(Comparator.naturalOrder()) .collect(Collectors.joining(", ")); @@ -94,10 +97,15 @@ private void inject(PsiClass psiClass, Collection targets) { } } - private String decorate(String iface) { + private String possiblyImport(@Nullable ImportHelper helper, StubStore.InterfaceInformation info) { + var interfaceName = helper == null ? info.interfaceDeclaration() : helper.importClass(info.interfaceDeclaration()); + return info.generics().isBlank() ? interfaceName : (interfaceName + "<" + info.generics() + ">"); + } + + private String decorate(@Nullable ImportHelper helper, String iface) { if (marker == null) { return iface; } - return "@" + marker + " " + iface; + return "@" + (helper == null ? marker : helper.importClass(marker)) + " " + iface; } } diff --git a/tests/data/interfaceinjection/additive_injection/expected/net/Example.java b/tests/data/interfaceinjection/additive_injection/expected/net/Example.java index 1552f09..6d73e0e 100644 --- a/tests/data/interfaceinjection/additive_injection/expected/net/Example.java +++ b/tests/data/interfaceinjection/additive_injection/expected/net/Example.java @@ -1,4 +1,6 @@ package net; -public class Example implements Runnable, com.example.InjectedInterface { +import com.example.InjectedInterface; + +public class Example implements Runnable, InjectedInterface { } diff --git a/tests/data/interfaceinjection/additive_injection/expected/net/Example2.java b/tests/data/interfaceinjection/additive_injection/expected/net/Example2.java index c1a7710..7766d9d 100644 --- a/tests/data/interfaceinjection/additive_injection/expected/net/Example2.java +++ b/tests/data/interfaceinjection/additive_injection/expected/net/Example2.java @@ -2,5 +2,7 @@ import java.util.*; -public class Example2 implements Runnable, Consumer, com.example.InjectedInterface { +import com.example.InjectedInterface; + +public class Example2 implements Runnable, Consumer, InjectedInterface { } diff --git a/tests/data/interfaceinjection/generics/expected/com/MyTarget.java b/tests/data/interfaceinjection/generics/expected/com/MyTarget.java index 220e439..dbbc084 100644 --- a/tests/data/interfaceinjection/generics/expected/com/MyTarget.java +++ b/tests/data/interfaceinjection/generics/expected/com/MyTarget.java @@ -1,4 +1,6 @@ package com; -public class MyTarget implements com.InjectedGeneric> { +import com.InjectedGeneric; + +public class MyTarget implements InjectedGeneric> { } diff --git a/tests/data/interfaceinjection/injected_marker/expected/SomeClass.java b/tests/data/interfaceinjection/injected_marker/expected/SomeClass.java index 76b9e8b..cd0c44e 100644 --- a/tests/data/interfaceinjection/injected_marker/expected/SomeClass.java +++ b/tests/data/interfaceinjection/injected_marker/expected/SomeClass.java @@ -1,2 +1,5 @@ -public class SomeClass implements @com.markers.InjectedMarker com.example.InjectedInterface { +import com.example.InjectedInterface; +import com.markers.InjectedMarker; + +public class SomeClass implements @InjectedMarker InjectedInterface { } diff --git a/tests/data/interfaceinjection/inner_stubs/expected/ExampleClass.java b/tests/data/interfaceinjection/inner_stubs/expected/ExampleClass.java index b6c1385..d0b871f 100644 --- a/tests/data/interfaceinjection/inner_stubs/expected/ExampleClass.java +++ b/tests/data/interfaceinjection/inner_stubs/expected/ExampleClass.java @@ -1,2 +1,5 @@ -public class ExampleClass implements com.example.InjectedInterface.Inner, com.example.InjectedInterface.Inner.SubInner { +import com.example.InjectedInterface.Inner; +import com.example.InjectedInterface.Inner.SubInner; + +public class ExampleClass implements Inner, SubInner { } diff --git a/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterface.java b/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterface.java index 2f6baad..4f512d2 100644 --- a/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterface.java +++ b/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterface.java @@ -1,4 +1,6 @@ package com.example; -public interface ExampleInterface extends com.example.InjectedInterface { +import com.example.InjectedInterface; + +public interface ExampleInterface extends InjectedInterface { } diff --git a/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterfaceAdditive.java b/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterfaceAdditive.java index c62ac5e..eae8f5f 100644 --- a/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterfaceAdditive.java +++ b/tests/data/interfaceinjection/interface_target/expected/com/example/ExampleInterfaceAdditive.java @@ -1,4 +1,6 @@ package com.example; -public interface ExampleInterfaceAdditive extends Runnable, com.example.InjectedInterface { +import com.example.InjectedInterface; + +public interface ExampleInterfaceAdditive extends Runnable, InjectedInterface { } diff --git a/tests/data/interfaceinjection/multiple_interfaces/expected/MyTarget.java b/tests/data/interfaceinjection/multiple_interfaces/expected/MyTarget.java index a502b5e..1a879b8 100644 --- a/tests/data/interfaceinjection/multiple_interfaces/expected/MyTarget.java +++ b/tests/data/interfaceinjection/multiple_interfaces/expected/MyTarget.java @@ -1,2 +1,5 @@ -public class MyTarget implements com.example.I1, com.example.I2 { +import com.example.I1; +import com.example.I2; + +public class MyTarget implements I1, I2 { } diff --git a/tests/data/interfaceinjection/simple_injection/expected/net/me/Example.java b/tests/data/interfaceinjection/simple_injection/expected/net/me/Example.java index 4ef26d8..17d3c71 100644 --- a/tests/data/interfaceinjection/simple_injection/expected/net/me/Example.java +++ b/tests/data/interfaceinjection/simple_injection/expected/net/me/Example.java @@ -1,4 +1,6 @@ package net.me; -public class Example implements com.example.InjectedInterface { +import com.example.InjectedInterface; + +public class Example implements InjectedInterface { } diff --git a/tests/data/interfaceinjection/simple_injection/expected/net/me/Example2.java b/tests/data/interfaceinjection/simple_injection/expected/net/me/Example2.java index ab2a748..d282a7e 100644 --- a/tests/data/interfaceinjection/simple_injection/expected/net/me/Example2.java +++ b/tests/data/interfaceinjection/simple_injection/expected/net/me/Example2.java @@ -1,4 +1,6 @@ package net.me; -public class Example2 extends Object implements com.example.InjectedInterface { +import com.example.InjectedInterface; + +public class Example2 extends Object implements InjectedInterface { } diff --git a/tests/data/interfaceinjection/stubs/expected/ExampleClass.java b/tests/data/interfaceinjection/stubs/expected/ExampleClass.java index 9cb6328..b5f091f 100644 --- a/tests/data/interfaceinjection/stubs/expected/ExampleClass.java +++ b/tests/data/interfaceinjection/stubs/expected/ExampleClass.java @@ -1,2 +1,5 @@ -public class ExampleClass implements InjectedRootInterface, com.example.II2, com.example.InjectedInterface { +import com.example.II2; +import com.example.InjectedInterface; + +public class ExampleClass implements II2, InjectedInterface, InjectedRootInterface { }