]> source.dussan.org Git - javassist.git/commitdiff
support bootstrap method coping when using code coping 480/head
authorLiuChuliang 刘楚梁 <chuliang.liu@smil.com>
Thu, 18 Jan 2024 08:11:16 +0000 (16:11 +0800)
committerLiuChuliang 刘楚梁 <chuliang.liu@smil.com>
Thu, 18 Jan 2024 08:11:16 +0000 (16:11 +0800)
src/main/javassist/bytecode/BootstrapMethodsAttribute.java
src/main/javassist/bytecode/CodeAttribute.java
src/test/javassist/bytecode/BytecodeTest.java
src/test/test4/InvokeDynCopyDest.java [new file with mode: 0644]
src/test/test4/InvokeDynCopySrc.java [new file with mode: 0644]

index 94a0481fc1432705568af0d38ed38b3040afd51e..0fd04cf8af14e8e3bcc6606ba04336fa27eaee05 100644 (file)
@@ -2,6 +2,7 @@ package javassist.bytecode;
 
 import java.io.DataInputStream;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Map;
 
 public class BootstrapMethodsAttribute extends AttributeInfo {
@@ -35,6 +36,26 @@ public class BootstrapMethodsAttribute extends AttributeInfo {
          * <code>bootstrap_arguments</code>.
          */
         public int[] arguments;
+
+        /**
+         * Makes a copy.  Class names are replaced according to the
+         *          * given <code>Map</code> object.
+         *
+         * @param srcCp     the constant pool table from the source
+         * @param destCp    the constant pool table used bt new copy
+         * @param classnames    pairs of replaced and substituted class names.
+         *
+         * @return new BootstrapMethod
+         */
+        protected BootstrapMethod copy(ConstPool srcCp, ConstPool destCp, Map<String,String> classnames) {
+            int newMethodRef = srcCp.copy(methodRef, destCp, classnames);
+            int[] newArguments = new int[arguments.length];
+
+            for (int i = 0; i < arguments.length; i++)
+                newArguments[i] = srcCp.copy(arguments[i], destCp, classnames);
+
+            return new BootstrapMethod(newMethodRef, newArguments);
+        }
     }
 
     BootstrapMethodsAttribute(ConstPool cp, int n, DataInputStream in)
@@ -51,25 +72,8 @@ public class BootstrapMethodsAttribute extends AttributeInfo {
      */
     public BootstrapMethodsAttribute(ConstPool cp, BootstrapMethod[] methods) {
         super(cp, tag);
-        int size = 2;
-        for (int i = 0; i < methods.length; i++)
-            size += 4 + methods[i].arguments.length * 2;
-
-        byte[] data = new byte[size];
-        ByteArray.write16bit(methods.length, data, 0);    // num_bootstrap_methods
-        int pos = 2;
-        for (int i = 0; i < methods.length; i++) {
-            ByteArray.write16bit(methods[i].methodRef, data, pos);
-            ByteArray.write16bit(methods[i].arguments.length, data, pos + 2);
-            int[] args = methods[i].arguments;
-            pos += 4;
-            for (int k = 0; k < args.length; k++) {
-                ByteArray.write16bit(args[k], data, pos);
-                pos += 2;
-            }
-        }
 
-        set(data);
+        set(convertMethodsToBytes(methods));
     }
 
     /**
@@ -113,12 +117,67 @@ public class BootstrapMethodsAttribute extends AttributeInfo {
         BootstrapMethod[] methods = getMethods();
         ConstPool thisCp = getConstPool();
         for (int i = 0; i < methods.length; i++) {
-            BootstrapMethod m = methods[i];
-            m.methodRef = thisCp.copy(m.methodRef, newCp, classnames);
-            for (int k = 0; k < m.arguments.length; k++)
-                m.arguments[k] = thisCp.copy(m.arguments[k], newCp, classnames);
+            methods[i] = methods[i].copy(thisCp, newCp, classnames);
         }
 
         return new BootstrapMethodsAttribute(newCp, methods);
     }
+
+    /**
+     * add bootstrap method from given <code>ConstPool</code> and <code>BootstrapMethod</code>,
+     * and add it to the specified index. Class names are replaced according to the
+     * given <code>Map</code> object.
+     *
+     * <p>
+     *      if the index less than 0 or large than the origin method length, then throw <code>RuntimeException</code>;<br>
+     *      if the index large or equals to 0 and less or equals to the origin method length,
+     *          then replace the origin method with the new <code>BootstrapMethod srcBm</code> ;<br>
+     *      if the index equals to the origin method length, then append the new <code>BootstrapMethod srcBm</code> at
+     *          the origin methods tail.
+     * </p>
+     *
+     * @param srcCp     the constant pool table of source.
+     * @param srcBm     the bootstrap method of source
+     * @param index     the new method index on bootstrap methods
+     * @param classnames        pairs of replaced and substituted
+     *                          class names.
+     */
+    public void addMethod(ConstPool srcCp, BootstrapMethod srcBm, int index, Map<String,String> classnames) {
+        BootstrapMethod[] methods = getMethods();
+
+        if (index < 0 || index > methods.length) {
+            throw new RuntimeException("index out of range");
+        }
+
+        if (index == methods.length) {
+            BootstrapMethod[] newBmArray = new BootstrapMethod[methods.length + 1];
+            System.arraycopy(methods, 0, newBmArray, 0, methods.length);
+            methods = newBmArray;
+        }
+
+        methods[index] = srcBm.copy(srcCp, getConstPool(), classnames);
+        set(convertMethodsToBytes(methods));
+    }
+
+    private static byte[] convertMethodsToBytes(BootstrapMethod[] methods) {
+        int size = 2;
+        for (int i = 0; i < methods.length; i++)
+            size += 4 + methods[i].arguments.length * 2;
+
+        byte[] data = new byte[size];
+        ByteArray.write16bit(methods.length, data, 0);    // num_bootstrap_methods
+        int pos = 2;
+        for (int i = 0; i < methods.length; i++) {
+            ByteArray.write16bit(methods[i].methodRef, data, pos);
+            ByteArray.write16bit(methods[i].arguments.length, data, pos + 2);
+            int[] args = methods[i].arguments;
+            pos += 4;
+            for (int k = 0; k < args.length; k++) {
+                ByteArray.write16bit(args[k], data, pos);
+                pos += 2;
+            }
+        }
+
+        return data;
+    }
 }
index 4c8ea2f270a17b73e7d0ad62b8596d00e07e89b3..091edf4cbf334ad40030efd1fa3db6dc199dddef 100644 (file)
@@ -16,6 +16,8 @@
 
 package javassist.bytecode;
 
+import javassist.*;
+
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
@@ -76,8 +78,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
      *                          class names.
      */
     private CodeAttribute(ConstPool cp, CodeAttribute src, Map<String,String> classnames)
-        throws BadBytecode
-    {
+            throws BadBytecode, NotFoundException, CannotCompileException {
         super(cp, tag);
 
         maxStack = src.getMaxStack();
@@ -139,6 +140,10 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
         }
         catch (BadBytecode e) {
             throw new RuntimeCopyException("bad bytecode. fatal?");
+        } catch (NotFoundException e) {
+            throw new RuntimeException(e);
+        } catch (CannotCompileException e) {
+            throw new RuntimeException(e);
         }
     }
 
@@ -324,7 +329,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
      *
      * @param smt       the stack map table added to this code attribute.
      *                  If it is null, a new stack map is not added.
-     *                  Only the old stack map is removed. 
+     *                  Only the old stack map is removed.
      */
     public void setAttribute(StackMapTable smt) {
         AttributeInfo.remove(attributes, StackMapTable.tag);
@@ -352,11 +357,11 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
      */
     private byte[] copyCode(ConstPool destCp, Map<String,String> classnames,
                             ExceptionTable etable, CodeAttribute destCa)
-        throws BadBytecode
-    {
+            throws BadBytecode, NotFoundException, CannotCompileException {
         int len = getCodeLength();
         byte[] newCode = new byte[len];
         destCa.info = newCode;
+
         LdcEntry ldc = copyCode(this.info, 0, len, this.getConstPool(),
                                 newCode, destCp, classnames);
         return LdcEntry.doit(newCode, ldc, etable, destCa);
@@ -364,9 +369,8 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
 
     private static LdcEntry copyCode(byte[] code, int beginPos, int endPos,
                                      ConstPool srcCp, byte[] newcode,
-                                     ConstPool destCp, Map<String,String> classnameMap)
-        throws BadBytecode
-    {
+                                                 ConstPool destCp, Map<String,String> classnameMap)
+            throws BadBytecode, NotFoundException, CannotCompileException {
         int i2, index;
         LdcEntry ldcEntry = null;
 
@@ -415,6 +419,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
             case INVOKEDYNAMIC :
                 copyConstPoolInfo(i + 1, code, srcCp, newcode, destCp,
                         classnameMap);
+                copyBootstrapMethod(srcCp, destCp, i + 1, code, newcode, classnameMap);
                 newcode[i + 3] = 0;
                 newcode[i + 4] = 0;
                 break;
@@ -434,6 +439,134 @@ public class CodeAttribute extends AttributeInfo implements Opcode {
         return ldcEntry;
     }
 
+    /**
+     *  Copy the Bootstrap method of the specified index referenced in the source <code>InvokeDynamic</code> directive
+     *  to the specified index in the destination Boostrap Attribute.<br>
+     *  if the Bootstrap Attribute does not exist in the destination class, create a new Bootstrap Attribute; <br>
+     *  if the destination Bootstrap Method already exists at the specified index method,
+     *      the method at that position will be overwritten, otherwise it will be added
+     *      at the end of the destination Bootstrap method.
+     *
+     * @param srcCp     the constant pool table of source
+     * @param destCp    the constant pool table of destination
+     * @param codeIndex     the index of the invoke dynamic first parameter in code array
+     * @param srcCode       the code array of source
+     * @param newCode       the code array of destination
+     * @param classnameMap  pairs of replaced and substituted class names.
+     *
+     * @throws NotFoundException        this exception thrown when the class
+     *                                      cannot be found in the default <code>ClassPool</code>
+     * @throws CannotCompileException   this exception thrown from the method
+     *                                       {@link #copyInvokeStaticMethod(CtClass, ConstPool,
+     *                                          BootstrapMethodsAttribute.BootstrapMethod, CtClass, Map)}
+     */
+    private static void copyBootstrapMethod(ConstPool srcCp, ConstPool destCp, int codeIndex, byte[] srcCode,
+                                            byte[] newCode, Map<String,String> classnameMap)
+            throws NotFoundException, CannotCompileException {
+        ClassPool classPool = ClassPool.getDefault();
+        CtClass srcCc = classPool.get(srcCp.getClassName());
+        CtClass destCc = classPool.get(destCp.getClassName());
+        ClassFile srcCf = srcCc.getClassFile();
+        ClassFile destCf = destCc.getClassFile();
+        BootstrapMethodsAttribute srcBma = (BootstrapMethodsAttribute)
+                srcCf.getAttribute(BootstrapMethodsAttribute.tag);
+
+        // if source class does not have bootstrap attribute then stop copy
+        if (srcBma == null) {
+            return;
+        }
+
+        BootstrapMethodsAttribute destBma = (BootstrapMethodsAttribute)
+                destCf.getAttribute(BootstrapMethodsAttribute.tag);
+
+        int srcCpIndex = ((srcCode[codeIndex] & 0xff) << 8) | (srcCode[codeIndex + 1] & 0xff);
+        int destCpIndex = ((newCode[codeIndex] & 0xff) << 8) | (newCode[codeIndex + 1] & 0xff);
+        int srcBmIndex = srcCp.getInvokeDynamicBootstrap(srcCpIndex);
+        int destBmIndex = destCp.getInvokeDynamicBootstrap(destCpIndex);
+
+        // if source class does not have bootstrap attribute, then create bootstrap attribute
+        if (destBma == null) {
+            destBma = new BootstrapMethodsAttribute(destCp,
+                    new BootstrapMethodsAttribute.BootstrapMethod[0]);
+            destCf.addAttribute(destBma);
+        }
+
+        BootstrapMethodsAttribute.BootstrapMethod srcBm = srcBma.getMethods()[srcBmIndex];
+        destBma.addMethod(srcCp, srcBm, destBmIndex, classnameMap);
+
+        copyInvokeStaticMethod(srcCc, srcCp, srcBm, destCc, classnameMap);
+    }
+
+    /**
+     *  Copy the static methods referenced by the bootstrap method in this class (such as some lambda methods).<br>
+     *  If the source method exists in the destination class, it will be ignored.
+     *
+     * @param srcCc     source class
+     * @param srcCp     constant pool table of source class
+     * @param srcBm     source method to be copied
+     * @param destCc    destination class
+     * @param classnameMap      irs of replaced and substituted class names.
+     *
+     * @throws CannotCompileException   thrown by {@link CtNewMethod#copy(CtMethod, CtClass, ClassMap)}
+     *                                          or{@link CtClass#addMethod(CtMethod)}
+     */
+    private static void copyInvokeStaticMethod(CtClass srcCc, ConstPool srcCp,
+                                               BootstrapMethodsAttribute.BootstrapMethod srcBm, CtClass destCc,
+                                               Map<String, String> classnameMap) throws CannotCompileException {
+        for (int argument : srcBm.arguments) {
+            ConstInfo constInfo = srcCp.getItem(argument);
+
+            if (!(constInfo instanceof MethodHandleInfo)) continue;
+
+            MethodHandleInfo methodHandleInfo = (MethodHandleInfo) constInfo;
+            if (ConstPool.REF_invokeStatic != methodHandleInfo.refKind) continue;
+
+            String methodRefClassName = srcCp.getMethodrefClassName(methodHandleInfo.refIndex);
+            if (methodRefClassName == null || !methodRefClassName.equals(srcCc.getName())) continue;
+
+            String staticMethodName = srcCp.getMethodrefName(methodHandleInfo.refIndex);
+            String staticMethodSignature = srcCp.getMethodrefType(methodHandleInfo.refIndex);
+            CtMethod srcMethod = getStaticCtMethod(srcCc, staticMethodName, staticMethodSignature);
+
+            if (!checkStaticMethodExisted(destCc, staticMethodName, staticMethodSignature)) {
+                ClassMap classMap = new ClassMap();
+                classMap.putAll(classnameMap);
+
+                CtMethod ctMethod = CtNewMethod.copy(srcMethod, destCc, classMap);
+                destCc.addMethod(ctMethod);
+            }
+        }
+    }
+
+    private static CtMethod getStaticCtMethod(CtClass ctClass, String staticMethodName, String staticMethodSignature) {
+        CtMethod srcMethod = null;
+        for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) {
+            if (Modifier.isStatic(declaredMethod.getModifiers())
+                    && declaredMethod.getName().equals(staticMethodName)
+                    && declaredMethod.getSignature().equals(staticMethodSignature)) {
+                srcMethod = declaredMethod;
+                break;
+            }
+        }
+
+        if (srcMethod == null) {
+            throw new RuntimeException("Can not found static method:" + staticMethodName);
+        }
+        return srcMethod;
+    }
+
+    private static boolean checkStaticMethodExisted(CtClass ctClass, String staticMethodName, String staticMethodSignature) {
+        for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) {
+            if (Modifier.isStatic(declaredMethod.getModifiers())
+                    && declaredMethod.getName().equals(staticMethodName)
+                    && declaredMethod.getSignature().equals(staticMethodSignature)) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
     private static void copyConstPoolInfo(int i, byte[] code, ConstPool srcCp,
                                           byte[] newcode, ConstPool destCp,
                                           Map<String,String> classnameMap) {
index eac420bc12a2e6f61a64aa7a428278378ca9d45f..68e3b2c263c6a1c1b6d7f099855df326fa4b39b8 100644 (file)
@@ -6,6 +6,7 @@ import junit.framework.*;
 import javassist.*;
 import javassist.bytecode.annotation.*;
 import javassist.bytecode.SignatureAttribute.*;
+import test4.InvokeDynCopyDest;
 
 @SuppressWarnings("unused")
 public class BytecodeTest extends TestCase {
@@ -461,19 +462,19 @@ public class BytecodeTest extends TestCase {
 
     public void testSignatureChange() throws Exception {
         changeMsig("<S:Ljava/lang/Object;>(TS;[TS;)Ljava/lang/Object", "java/lang/Object",
-                   "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object", "java/lang/Objec"); 
+                   "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object", "java/lang/Objec");
         changeMsig("<S:Ljava/lang/Object;>(TS;[TS;)TT;", "java/lang/Object",
-                   "<S:Ljava/lang/Objec;>(TS;[TS;)TT;", "java/lang/Objec"); 
+                   "<S:Ljava/lang/Objec;>(TS;[TS;)TT;", "java/lang/Objec");
         changeMsig("<S:Ljava/lang/Object;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Object",
-                   "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); 
+                   "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec");
         changeMsig("<S:Ljava/lang/Object;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object",
-                   "<S:Ljava/lang/Object2;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); 
+                   "<S:Ljava/lang/Object2;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2");
         changeMsig2("<S:Ljava/lang/Object;>(TS;[TS;)TT;", "java/lang/Object",
-                    "<S:Ljava/lang/Objec;>(TS;[TS;)TT;", "java/lang/Objec"); 
+                    "<S:Ljava/lang/Objec;>(TS;[TS;)TT;", "java/lang/Objec");
         changeMsig2("<S:Ljava/lang/Object;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Object",
-                    "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); 
+                    "<S:Ljava/lang/Objec;>(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec");
         changeMsig2("<S:Ljava/lang/Object;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object",
-                    "<S:Ljava/lang/Object2;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); 
+                    "<S:Ljava/lang/Object2;>(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2");
         String sig = "<T:Ljava/lang/Exception;>LPoi$Foo<Ljava/lang/String;>;LBar;LBar2;";
         String res = "<T:Ljava/lang/Exception;>LPoi$Foo<Ljava/lang/String2;>;LBar;LBar2;";
         changeMsig(sig, "java/lang/String", res, "java/lang/String2");
@@ -683,7 +684,7 @@ public class BytecodeTest extends TestCase {
         assertFalse(fi1.equals(fi3));
         assertFalse(fi1.equals(ci1));
         assertFalse(fi1.equals(null));
-       
+
         LongInfo li1 = new LongInfo(12345L, n++);
         LongInfo li2 = new LongInfo(12345L, n++);
         LongInfo li3 = new LongInfo(-12345L, n++);
@@ -834,6 +835,28 @@ public class BytecodeTest extends TestCase {
         assertEquals("(I)V", cPool2.getUtf8Info(cPool2.getMethodTypeInfo(mtIndex)));
     }
 
+    public void testInvokeDynamicWithCopy() throws Exception {
+        CtClass srcCc = loader.get("test4.InvokeDynCopySrc");
+        CtClass destCc = loader.get("test4.InvokeDynCopyDest");
+
+        // copy source constructor to dest
+        for (CtConstructor constructor : destCc.getConstructors()) {
+            for (CtConstructor srcClassConstructor : srcCc.getConstructors()) {
+                if (constructor.getSignature().equalsIgnoreCase(srcClassConstructor.getSignature())) {
+                    constructor.setBody(srcClassConstructor, null);
+                }
+            }
+        }
+
+        // set dest class method body by source class
+        destCc.getDeclaredMethod("getString").setBody(srcCc.getDeclaredMethod("getString"), new ClassMap());
+
+        Object destObj = (new Loader(loader)).loadClass(destCc.getName()).getConstructor().newInstance();
+
+        // if don't copy bootstrap method and static lambda method it will throw exception when invoke
+        assertEquals("hello", destObj.getClass().getMethod("getString").invoke(destObj));
+    }
+
     public static Test suite() {
         TestSuite suite = new TestSuite("Bytecode Tests");
         suite.addTestSuite(BytecodeTest.class);
diff --git a/src/test/test4/InvokeDynCopyDest.java b/src/test/test4/InvokeDynCopyDest.java
new file mode 100644 (file)
index 0000000..003c7c5
--- /dev/null
@@ -0,0 +1,11 @@
+package test4;
+
+public class InvokeDynCopyDest {
+    public InvokeDynCopyDest() {
+        System.out.println("my output:" + getString());
+    }
+
+    public String getString() {
+        return "dest";
+    }
+}
diff --git a/src/test/test4/InvokeDynCopySrc.java b/src/test/test4/InvokeDynCopySrc.java
new file mode 100644 (file)
index 0000000..83291f7
--- /dev/null
@@ -0,0 +1,17 @@
+package test4;
+
+import java.util.function.Supplier;
+
+public class InvokeDynCopySrc {
+    public InvokeDynCopySrc() {
+        System.out.println("source class:" + getString());
+    }
+
+    public String getString() {
+        Supplier<String> stringSupplier = () -> {
+            return "hello";
+        };
+
+        return stringSupplier.get();
+    }
+}