Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions api/src/main/java/net/neoforged/jst/api/ImportHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package net.neoforged.jst.api;

import com.intellij.psi.PsiClass;
import com.intellij.psi.PsiField;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiImportStatementBase;
import com.intellij.psi.PsiImportStaticStatement;
import com.intellij.psi.PsiJavaFile;
import com.intellij.psi.PsiMethod;
import com.intellij.psi.PsiModifier;
import com.intellij.psi.PsiPackage;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Helper class used to import classes while processing a source file.
* @see ImportHelper#get(PsiFile)
*/
public class ImportHelper implements PostProcessReplacer {
private final PsiJavaFile psiFile;
private final Map<String, String> importedNames = new HashMap<>();

private final Set<String> successfulImports = new HashSet<>();

public ImportHelper(PsiJavaFile psiFile) {
this.psiFile = psiFile;

if (psiFile.getPackageStatement() != null) {
var resolved = psiFile.getPackageStatement().getPackageReference().resolve();
// We cannot import a class with the name of a class in the package of the file
if (resolved instanceof PsiPackage pkg) {
for (PsiClass cls : pkg.getClasses()) {
importedNames.put(cls.getName(), cls.getQualifiedName());
}
}
}

if (psiFile.getImportList() != null) {
for (PsiImportStatementBase stmt : psiFile.getImportList().getImportStatements()) {
var res = stmt.resolve();
if (res instanceof PsiPackage pkg) {
// Wildcard package imports will reserve all names of top-level classes in the package
for (PsiClass cls : pkg.getClasses()) {
importedNames.put(cls.getName(), cls.getQualifiedName());
}
} else if (res instanceof PsiClass cls) {
importedNames.put(cls.getName(), cls.getQualifiedName());
}
}

for (PsiImportStaticStatement stmt : psiFile.getImportList().getImportStaticStatements()) {
var res = stmt.resolve();
if (res instanceof PsiMethod method) {
importedNames.put(method.getName(), method.getName());
} else if (res instanceof PsiField fld) {
importedNames.put(fld.getName(), fld.getName());
} else if (res instanceof PsiClass cls && stmt.isOnDemand()) {
// On-demand imports are static wildcard imports which will reserve the names of
// - all static methods available through the imported class
for (PsiMethod met : cls.getAllMethods()) {
if (met.getModifierList().hasModifierProperty(PsiModifier.STATIC)) {
importedNames.put(met.getName(), met.getName());
}
}

// - all fields available through the imported class
for (PsiField fld : cls.getAllFields()) {
if (fld.getModifierList() != null && fld.getModifierList().hasModifierProperty(PsiModifier.STATIC)) {
importedNames.put(fld.getName(), fld.getName());
}
}

// - all inner classes available through the imported class directly
for (PsiClass c : cls.getAllInnerClasses()) {
importedNames.put(c.getName(), c.getQualifiedName());
}

// Note: to avoid possible issues, none of the above check for visibility. We prefer to be more conservative to make sure the output sources compile
}
}
}
}

@VisibleForTesting
public boolean canImport(String name) {
return !importedNames.containsKey(name);
}

/**
* Attempts to import the given fully qualified class name, returning a reference to it which is either
* its short name (if an import is successful) or the qualified name if not.
*/
public String importClass(String cls) {
var clsByDot = cls.split("\\.");
// We do not try to import classes in the default package or classes already imported
if (clsByDot.length == 1 || successfulImports.contains(cls)) {
return clsByDot[clsByDot.length - 1];
}
// We also do not want to import classes under java.lang.*
else if (clsByDot.length == 3 && clsByDot[0].equals("java") && clsByDot[1].equals("lang")) {
return clsByDot[2];
}

var name = clsByDot[clsByDot.length - 1];

if (Objects.equals(importedNames.get(name), cls)) {
return name;
}

if (canImport(name)) {
successfulImports.add(cls);
return name;
}

return cls;
}

@Override
public void process(Replacements replacements) {
if (successfulImports.isEmpty()) return;

var insertion = successfulImports.stream()
.sorted()
.map(s -> "import " + s + ";")
.collect(Collectors.joining("\n"));

if (psiFile.getImportList() != null && psiFile.getImportList().getLastChild() != null) {
var lastImport = psiFile.getImportList().getLastChild();
replacements.insertAfter(lastImport, "\n\n" + insertion);
} else {
replacements.insertBefore(psiFile.getClasses()[0], insertion + "\n\n");
}
}

@Nullable
public static ImportHelper get(PsiFile file) {
return file instanceof PsiJavaFile j ? get(j) : null;
}

public static ImportHelper get(PsiJavaFile file) {
return PostProcessReplacer.getOrCreateReplacer(file, ImportHelper.class, k -> new ImportHelper(file));
}
}
38 changes: 38 additions & 0 deletions api/src/main/java/net/neoforged/jst/api/PostProcessReplacer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package net.neoforged.jst.api;

import com.intellij.openapi.util.Key;
import com.intellij.psi.PsiFile;
import org.jetbrains.annotations.UnmodifiableView;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.function.Function;

/**
* A replacer linked to a {@link PsiFile} will run and collect replacements after all {@link SourceTransformer transformers} have processed the file.
*/
public interface PostProcessReplacer {
Key<Map<Class<?>, PostProcessReplacer>> REPLACERS = Key.create("jst.post_process_replacers");

/**
* Process replacements in the file after {@link SourceTransformer transformers} have processed it.
*/
void process(Replacements replacements);

@UnmodifiableView
static Map<Class<?>, PostProcessReplacer> getReplacers(PsiFile file) {
var rep = file.getUserData(REPLACERS);
return rep == null ? Map.of() : Collections.unmodifiableMap(rep);
}

static <T extends PostProcessReplacer> T getOrCreateReplacer(PsiFile file, Class<T> type, Function<PsiFile, T> creator) {
var rep = file.getUserData(REPLACERS);
if (rep == null) {
rep = new IdentityHashMap<>();
file.putUserData(REPLACERS, rep);
}
//noinspection unchecked
return (T)rep.computeIfAbsent(type, k -> creator.apply(file));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import net.neoforged.jst.api.FileSink;
import net.neoforged.jst.api.FileSource;
import net.neoforged.jst.api.Logger;
import net.neoforged.jst.api.PostProcessReplacer;
import net.neoforged.jst.api.Replacement;
import net.neoforged.jst.api.Replacements;
import net.neoforged.jst.api.SourceTransformer;
Expand Down Expand Up @@ -153,6 +154,10 @@ private byte[] transformSource(VirtualFile contentRoot, FileEntry entry, List<So
transformer.visitFile(psiFile, replacements);
}

for (PostProcessReplacer rep : PostProcessReplacer.getReplacers(psiFile).values()) {
rep.process(replacements);
}

var readOnlyReplacements = Collections.unmodifiableList(replacementsList);
boolean success = true;
for (var transformer : transformers) {
Expand Down
133 changes: 133 additions & 0 deletions cli/src/test/java/net/neoforged/jst/cli/ImportHelperTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package net.neoforged.jst.cli;

import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiClass;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiJavaFile;
import com.intellij.psi.PsiMethod;
import com.intellij.psi.util.PsiTreeUtil;
import net.neoforged.jst.api.ImportHelper;
import net.neoforged.jst.api.Logger;
import net.neoforged.jst.api.Replacements;
import net.neoforged.jst.cli.intellij.IntelliJEnvironmentImpl;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.assertj.core.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class ImportHelperTest {
static IntelliJEnvironmentImpl ijEnv;

@BeforeAll
static void setUp() throws IOException {
ijEnv = new IntelliJEnvironmentImpl(new Logger(null, null));
ijEnv.addCurrentJdkToClassPath();
}

@AfterAll
static void tearDown() throws IOException {
ijEnv.close();
}

@Test
public void testSimpleImports() {
var helper = getImportHelper("""
import java.util.Collection;
import java.lang.annotation.Retention;
import java.util.concurrent.atomic.AtomicReference;""");

assertFalse(helper.canImport("Collection"), "Collection can wrongly be imported");
assertFalse(helper.canImport("Retention"), "Retention can wrongly be imported");
assertFalse(helper.canImport("AtomicReference"), "AtomicReference can wrongly be imported");

assertTrue(helper.canImport("MyRandomClass"), "Cannot import a non-reserved name");
}

@Test
public void testWildcardImports() {
var helper = getImportHelper("""
import java.util.concurrent.*;""");

assertFalse(helper.canImport("Future"), "Future can wrongly be imported");
assertFalse(helper.canImport("Executor"), "Executor can wrongly be imported");

assertTrue(helper.canImport("ThisWillNotExist"), "Cannot import a non-reserved name");
}

@Test
public void testStaticImports() {
var helper = getImportHelper("""
import static java.util.Spliterators.emptyDoubleSpliterator;
import static java.util.Collections.*;""");

assertFalse(helper.canImport("emptyDoubleSpliterator"), "emptyDoubleSpliterator can wrongly be imported");

assertFalse(helper.canImport("min"), "min can wrongly be imported");
assertFalse(helper.canImport("checkedSortedMap"), "checkedSortedMap can wrongly be imported");
assertFalse(helper.canImport("EMPTY_LIST"), "EMPTY_LIST can wrongly be imported");

assertTrue(helper.canImport("ThisWillNotExist"), "Cannot import a non-reserved name");
}

@Test
void testReplace() {
var file = parseSingleFile("""
package java.lang.annotation;

import java.util.*;

class MyClass {
}""");

var helper = ImportHelper.get(file);

assertEquals("HelloWorld", helper.importClass("com.hello.world.HelloWorld"));

assertEquals("Annotation", helper.importClass("java.lang.annotation.Annotation"));
assertEquals("com.hello.world.Annotation", helper.importClass("com.hello.world.Annotation"));

assertEquals("List", helper.importClass("java.util.List"));
assertEquals("com.hello.world.List", helper.importClass("com.hello.world.List"));

assertEquals("Thing", helper.importClass("a.b.c.Thing"));

var rep = new Replacements();
helper.process(rep);

assertThat(rep.apply(file.getText()))
.isEqualToNormalizingNewlines("""
package java.lang.annotation;

import java.util.*;

import a.b.c.Thing;
import com.hello.world.HelloWorld;

class MyClass {
}""");
}

private ImportHelper getImportHelper(@Language("JAVA") String javaCode) {
var file = parseSingleFile(javaCode);
return new ImportHelper(file);
}

private PsiJavaFile parseSingleFile(@Language("JAVA") String javaCode) {
return parseSingleElement(javaCode, PsiJavaFile.class);
}

private <T extends PsiElement> T parseSingleElement(@Language("JAVA") String javaCode, Class<T> type) {
var file = ijEnv.parseFileFromMemory("Test.java", javaCode);

var elements = PsiTreeUtil.collectElementsOfType(file, type);
assertEquals(1, elements.size());
return elements.iterator().next();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.intellij.psi.PsiWhiteSpace;
import com.intellij.psi.util.ClassUtil;
import com.intellij.util.containers.MultiMap;
import net.neoforged.jst.api.ImportHelper;
import net.neoforged.jst.api.Replacements;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -60,6 +61,8 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
return;
}

var imports = ImportHelper.get(psiClass.getContainingFile());

var implementsList = psiClass.isInterface() ? psiClass.getExtendsList() : psiClass.getImplementsList();
var implementedInterfaces = Arrays.stream(implementsList.getReferencedTypes())
.map(PsiClassType::resolve)
Expand All @@ -71,8 +74,8 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
.distinct()
.map(stubs::createStub)
.filter(iface -> !implementedInterfaces.contains(iface.interfaceDeclaration()))
.map(StubStore.InterfaceInformation::toString)
.map(this::decorate)
.map(iface -> possiblyImport(imports, iface))
.map(iface -> decorate(imports, iface))
.sorted(Comparator.naturalOrder())
.collect(Collectors.joining(", "));

Expand All @@ -94,10 +97,15 @@ private void inject(PsiClass psiClass, Collection<String> targets) {
}
}

private String decorate(String iface) {
private String possiblyImport(@Nullable ImportHelper helper, StubStore.InterfaceInformation info) {
var interfaceName = helper == null ? info.interfaceDeclaration() : helper.importClass(info.interfaceDeclaration());
return info.generics().isBlank() ? interfaceName : (interfaceName + "<" + info.generics() + ">");
}

private String decorate(@Nullable ImportHelper helper, String iface) {
if (marker == null) {
return iface;
}
return "@" + marker + " " + iface;
return "@" + (helper == null ? marker : helper.importClass(marker)) + " " + iface;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
package net;

public class Example implements Runnable, com.example.InjectedInterface {
import com.example.InjectedInterface;

public class Example implements Runnable, InjectedInterface {
}
Loading