From 9adc1a3526d58f7123a0837d4c5c68eb6a7099dd Mon Sep 17 00:00:00 2001 From: =?utf8?q?LiuChuliang=20=E5=88=98=E6=A5=9A=E6=A2=81?= Date: Thu, 18 Jan 2024 16:11:16 +0800 Subject: [PATCH] support bootstrap method coping when using code coping --- .../bytecode/BootstrapMethodsAttribute.java | 103 +++++++++--- .../javassist/bytecode/CodeAttribute.java | 149 +++++++++++++++++- src/test/javassist/bytecode/BytecodeTest.java | 39 ++++- src/test/test4/InvokeDynCopyDest.java | 11 ++ src/test/test4/InvokeDynCopySrc.java | 17 ++ 5 files changed, 281 insertions(+), 38 deletions(-) create mode 100644 src/test/test4/InvokeDynCopyDest.java create mode 100644 src/test/test4/InvokeDynCopySrc.java diff --git a/src/main/javassist/bytecode/BootstrapMethodsAttribute.java b/src/main/javassist/bytecode/BootstrapMethodsAttribute.java index 94a0481f..0fd04cf8 100644 --- a/src/main/javassist/bytecode/BootstrapMethodsAttribute.java +++ b/src/main/javassist/bytecode/BootstrapMethodsAttribute.java @@ -2,6 +2,7 @@ package javassist.bytecode; import java.io.DataInputStream; import java.io.IOException; +import java.util.Arrays; import java.util.Map; public class BootstrapMethodsAttribute extends AttributeInfo { @@ -35,6 +36,26 @@ public class BootstrapMethodsAttribute extends AttributeInfo { * bootstrap_arguments. */ public int[] arguments; + + /** + * Makes a copy. Class names are replaced according to the + * * given Map object. + * + * @param srcCp the constant pool table from the source + * @param destCp the constant pool table used bt new copy + * @param classnames pairs of replaced and substituted class names. + * + * @return new BootstrapMethod + */ + protected BootstrapMethod copy(ConstPool srcCp, ConstPool destCp, Map classnames) { + int newMethodRef = srcCp.copy(methodRef, destCp, classnames); + int[] newArguments = new int[arguments.length]; + + for (int i = 0; i < arguments.length; i++) + newArguments[i] = srcCp.copy(arguments[i], destCp, classnames); + + return new BootstrapMethod(newMethodRef, newArguments); + } } BootstrapMethodsAttribute(ConstPool cp, int n, DataInputStream in) @@ -51,25 +72,8 @@ public class BootstrapMethodsAttribute extends AttributeInfo { */ public BootstrapMethodsAttribute(ConstPool cp, BootstrapMethod[] methods) { super(cp, tag); - int size = 2; - for (int i = 0; i < methods.length; i++) - size += 4 + methods[i].arguments.length * 2; - - byte[] data = new byte[size]; - ByteArray.write16bit(methods.length, data, 0); // num_bootstrap_methods - int pos = 2; - for (int i = 0; i < methods.length; i++) { - ByteArray.write16bit(methods[i].methodRef, data, pos); - ByteArray.write16bit(methods[i].arguments.length, data, pos + 2); - int[] args = methods[i].arguments; - pos += 4; - for (int k = 0; k < args.length; k++) { - ByteArray.write16bit(args[k], data, pos); - pos += 2; - } - } - set(data); + set(convertMethodsToBytes(methods)); } /** @@ -113,12 +117,67 @@ public class BootstrapMethodsAttribute extends AttributeInfo { BootstrapMethod[] methods = getMethods(); ConstPool thisCp = getConstPool(); for (int i = 0; i < methods.length; i++) { - BootstrapMethod m = methods[i]; - m.methodRef = thisCp.copy(m.methodRef, newCp, classnames); - for (int k = 0; k < m.arguments.length; k++) - m.arguments[k] = thisCp.copy(m.arguments[k], newCp, classnames); + methods[i] = methods[i].copy(thisCp, newCp, classnames); } return new BootstrapMethodsAttribute(newCp, methods); } + + /** + * add bootstrap method from given ConstPool and BootstrapMethod, + * and add it to the specified index. Class names are replaced according to the + * given Map object. + * + *

+ * if the index less than 0 or large than the origin method length, then throw RuntimeException;
+ * if the index large or equals to 0 and less or equals to the origin method length, + * then replace the origin method with the new BootstrapMethod srcBm ;
+ * if the index equals to the origin method length, then append the new BootstrapMethod srcBm at + * the origin methods tail. + *

+ * + * @param srcCp the constant pool table of source. + * @param srcBm the bootstrap method of source + * @param index the new method index on bootstrap methods + * @param classnames pairs of replaced and substituted + * class names. + */ + public void addMethod(ConstPool srcCp, BootstrapMethod srcBm, int index, Map classnames) { + BootstrapMethod[] methods = getMethods(); + + if (index < 0 || index > methods.length) { + throw new RuntimeException("index out of range"); + } + + if (index == methods.length) { + BootstrapMethod[] newBmArray = new BootstrapMethod[methods.length + 1]; + System.arraycopy(methods, 0, newBmArray, 0, methods.length); + methods = newBmArray; + } + + methods[index] = srcBm.copy(srcCp, getConstPool(), classnames); + set(convertMethodsToBytes(methods)); + } + + private static byte[] convertMethodsToBytes(BootstrapMethod[] methods) { + int size = 2; + for (int i = 0; i < methods.length; i++) + size += 4 + methods[i].arguments.length * 2; + + byte[] data = new byte[size]; + ByteArray.write16bit(methods.length, data, 0); // num_bootstrap_methods + int pos = 2; + for (int i = 0; i < methods.length; i++) { + ByteArray.write16bit(methods[i].methodRef, data, pos); + ByteArray.write16bit(methods[i].arguments.length, data, pos + 2); + int[] args = methods[i].arguments; + pos += 4; + for (int k = 0; k < args.length; k++) { + ByteArray.write16bit(args[k], data, pos); + pos += 2; + } + } + + return data; + } } diff --git a/src/main/javassist/bytecode/CodeAttribute.java b/src/main/javassist/bytecode/CodeAttribute.java index 4c8ea2f2..091edf4c 100644 --- a/src/main/javassist/bytecode/CodeAttribute.java +++ b/src/main/javassist/bytecode/CodeAttribute.java @@ -16,6 +16,8 @@ package javassist.bytecode; +import javassist.*; + import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -76,8 +78,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode { * class names. */ private CodeAttribute(ConstPool cp, CodeAttribute src, Map classnames) - throws BadBytecode - { + throws BadBytecode, NotFoundException, CannotCompileException { super(cp, tag); maxStack = src.getMaxStack(); @@ -139,6 +140,10 @@ public class CodeAttribute extends AttributeInfo implements Opcode { } catch (BadBytecode e) { throw new RuntimeCopyException("bad bytecode. fatal?"); + } catch (NotFoundException e) { + throw new RuntimeException(e); + } catch (CannotCompileException e) { + throw new RuntimeException(e); } } @@ -324,7 +329,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode { * * @param smt the stack map table added to this code attribute. * If it is null, a new stack map is not added. - * Only the old stack map is removed. + * Only the old stack map is removed. */ public void setAttribute(StackMapTable smt) { AttributeInfo.remove(attributes, StackMapTable.tag); @@ -352,11 +357,11 @@ public class CodeAttribute extends AttributeInfo implements Opcode { */ private byte[] copyCode(ConstPool destCp, Map classnames, ExceptionTable etable, CodeAttribute destCa) - throws BadBytecode - { + throws BadBytecode, NotFoundException, CannotCompileException { int len = getCodeLength(); byte[] newCode = new byte[len]; destCa.info = newCode; + LdcEntry ldc = copyCode(this.info, 0, len, this.getConstPool(), newCode, destCp, classnames); return LdcEntry.doit(newCode, ldc, etable, destCa); @@ -364,9 +369,8 @@ public class CodeAttribute extends AttributeInfo implements Opcode { private static LdcEntry copyCode(byte[] code, int beginPos, int endPos, ConstPool srcCp, byte[] newcode, - ConstPool destCp, Map classnameMap) - throws BadBytecode - { + ConstPool destCp, Map classnameMap) + throws BadBytecode, NotFoundException, CannotCompileException { int i2, index; LdcEntry ldcEntry = null; @@ -415,6 +419,7 @@ public class CodeAttribute extends AttributeInfo implements Opcode { case INVOKEDYNAMIC : copyConstPoolInfo(i + 1, code, srcCp, newcode, destCp, classnameMap); + copyBootstrapMethod(srcCp, destCp, i + 1, code, newcode, classnameMap); newcode[i + 3] = 0; newcode[i + 4] = 0; break; @@ -434,6 +439,134 @@ public class CodeAttribute extends AttributeInfo implements Opcode { return ldcEntry; } + /** + * Copy the Bootstrap method of the specified index referenced in the source InvokeDynamic directive + * to the specified index in the destination Boostrap Attribute.
+ * if the Bootstrap Attribute does not exist in the destination class, create a new Bootstrap Attribute;
+ * if the destination Bootstrap Method already exists at the specified index method, + * the method at that position will be overwritten, otherwise it will be added + * at the end of the destination Bootstrap method. + * + * @param srcCp the constant pool table of source + * @param destCp the constant pool table of destination + * @param codeIndex the index of the invoke dynamic first parameter in code array + * @param srcCode the code array of source + * @param newCode the code array of destination + * @param classnameMap pairs of replaced and substituted class names. + * + * @throws NotFoundException this exception thrown when the class + * cannot be found in the default ClassPool + * @throws CannotCompileException this exception thrown from the method + * {@link #copyInvokeStaticMethod(CtClass, ConstPool, + * BootstrapMethodsAttribute.BootstrapMethod, CtClass, Map)} + */ + private static void copyBootstrapMethod(ConstPool srcCp, ConstPool destCp, int codeIndex, byte[] srcCode, + byte[] newCode, Map classnameMap) + throws NotFoundException, CannotCompileException { + ClassPool classPool = ClassPool.getDefault(); + CtClass srcCc = classPool.get(srcCp.getClassName()); + CtClass destCc = classPool.get(destCp.getClassName()); + ClassFile srcCf = srcCc.getClassFile(); + ClassFile destCf = destCc.getClassFile(); + BootstrapMethodsAttribute srcBma = (BootstrapMethodsAttribute) + srcCf.getAttribute(BootstrapMethodsAttribute.tag); + + // if source class does not have bootstrap attribute then stop copy + if (srcBma == null) { + return; + } + + BootstrapMethodsAttribute destBma = (BootstrapMethodsAttribute) + destCf.getAttribute(BootstrapMethodsAttribute.tag); + + int srcCpIndex = ((srcCode[codeIndex] & 0xff) << 8) | (srcCode[codeIndex + 1] & 0xff); + int destCpIndex = ((newCode[codeIndex] & 0xff) << 8) | (newCode[codeIndex + 1] & 0xff); + int srcBmIndex = srcCp.getInvokeDynamicBootstrap(srcCpIndex); + int destBmIndex = destCp.getInvokeDynamicBootstrap(destCpIndex); + + // if source class does not have bootstrap attribute, then create bootstrap attribute + if (destBma == null) { + destBma = new BootstrapMethodsAttribute(destCp, + new BootstrapMethodsAttribute.BootstrapMethod[0]); + destCf.addAttribute(destBma); + } + + BootstrapMethodsAttribute.BootstrapMethod srcBm = srcBma.getMethods()[srcBmIndex]; + destBma.addMethod(srcCp, srcBm, destBmIndex, classnameMap); + + copyInvokeStaticMethod(srcCc, srcCp, srcBm, destCc, classnameMap); + } + + /** + * Copy the static methods referenced by the bootstrap method in this class (such as some lambda methods).
+ * If the source method exists in the destination class, it will be ignored. + * + * @param srcCc source class + * @param srcCp constant pool table of source class + * @param srcBm source method to be copied + * @param destCc destination class + * @param classnameMap irs of replaced and substituted class names. + * + * @throws CannotCompileException thrown by {@link CtNewMethod#copy(CtMethod, CtClass, ClassMap)} + * or{@link CtClass#addMethod(CtMethod)} + */ + private static void copyInvokeStaticMethod(CtClass srcCc, ConstPool srcCp, + BootstrapMethodsAttribute.BootstrapMethod srcBm, CtClass destCc, + Map classnameMap) throws CannotCompileException { + for (int argument : srcBm.arguments) { + ConstInfo constInfo = srcCp.getItem(argument); + + if (!(constInfo instanceof MethodHandleInfo)) continue; + + MethodHandleInfo methodHandleInfo = (MethodHandleInfo) constInfo; + if (ConstPool.REF_invokeStatic != methodHandleInfo.refKind) continue; + + String methodRefClassName = srcCp.getMethodrefClassName(methodHandleInfo.refIndex); + if (methodRefClassName == null || !methodRefClassName.equals(srcCc.getName())) continue; + + String staticMethodName = srcCp.getMethodrefName(methodHandleInfo.refIndex); + String staticMethodSignature = srcCp.getMethodrefType(methodHandleInfo.refIndex); + CtMethod srcMethod = getStaticCtMethod(srcCc, staticMethodName, staticMethodSignature); + + if (!checkStaticMethodExisted(destCc, staticMethodName, staticMethodSignature)) { + ClassMap classMap = new ClassMap(); + classMap.putAll(classnameMap); + + CtMethod ctMethod = CtNewMethod.copy(srcMethod, destCc, classMap); + destCc.addMethod(ctMethod); + } + } + } + + private static CtMethod getStaticCtMethod(CtClass ctClass, String staticMethodName, String staticMethodSignature) { + CtMethod srcMethod = null; + for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) { + if (Modifier.isStatic(declaredMethod.getModifiers()) + && declaredMethod.getName().equals(staticMethodName) + && declaredMethod.getSignature().equals(staticMethodSignature)) { + srcMethod = declaredMethod; + break; + } + } + + if (srcMethod == null) { + throw new RuntimeException("Can not found static method:" + staticMethodName); + } + return srcMethod; + } + + private static boolean checkStaticMethodExisted(CtClass ctClass, String staticMethodName, String staticMethodSignature) { + for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) { + if (Modifier.isStatic(declaredMethod.getModifiers()) + && declaredMethod.getName().equals(staticMethodName) + && declaredMethod.getSignature().equals(staticMethodSignature)) { + return true; + } + } + + return false; + } + private static void copyConstPoolInfo(int i, byte[] code, ConstPool srcCp, byte[] newcode, ConstPool destCp, Map classnameMap) { diff --git a/src/test/javassist/bytecode/BytecodeTest.java b/src/test/javassist/bytecode/BytecodeTest.java index eac420bc..68e3b2c2 100644 --- a/src/test/javassist/bytecode/BytecodeTest.java +++ b/src/test/javassist/bytecode/BytecodeTest.java @@ -6,6 +6,7 @@ import junit.framework.*; import javassist.*; import javassist.bytecode.annotation.*; import javassist.bytecode.SignatureAttribute.*; +import test4.InvokeDynCopyDest; @SuppressWarnings("unused") public class BytecodeTest extends TestCase { @@ -461,19 +462,19 @@ public class BytecodeTest extends TestCase { public void testSignatureChange() throws Exception { changeMsig("(TS;[TS;)Ljava/lang/Object", "java/lang/Object", - "(TS;[TS;)Ljava/lang/Object", "java/lang/Objec"); + "(TS;[TS;)Ljava/lang/Object", "java/lang/Objec"); changeMsig("(TS;[TS;)TT;", "java/lang/Object", - "(TS;[TS;)TT;", "java/lang/Objec"); + "(TS;[TS;)TT;", "java/lang/Objec"); changeMsig("(TS;[TS;)Ljava/lang/Object2;", "java/lang/Object", - "(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); + "(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); changeMsig("(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object", - "(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); + "(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); changeMsig2("(TS;[TS;)TT;", "java/lang/Object", - "(TS;[TS;)TT;", "java/lang/Objec"); + "(TS;[TS;)TT;", "java/lang/Objec"); changeMsig2("(TS;[TS;)Ljava/lang/Object2;", "java/lang/Object", - "(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); + "(TS;[TS;)Ljava/lang/Object2;", "java/lang/Objec"); changeMsig2("(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object", - "(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); + "(TS;[TS;)Ljava/lang/Objec;", "java/lang/Object2"); String sig = "LPoi$Foo;LBar;LBar2;"; String res = "LPoi$Foo;LBar;LBar2;"; changeMsig(sig, "java/lang/String", res, "java/lang/String2"); @@ -683,7 +684,7 @@ public class BytecodeTest extends TestCase { assertFalse(fi1.equals(fi3)); assertFalse(fi1.equals(ci1)); assertFalse(fi1.equals(null)); - + LongInfo li1 = new LongInfo(12345L, n++); LongInfo li2 = new LongInfo(12345L, n++); LongInfo li3 = new LongInfo(-12345L, n++); @@ -834,6 +835,28 @@ public class BytecodeTest extends TestCase { assertEquals("(I)V", cPool2.getUtf8Info(cPool2.getMethodTypeInfo(mtIndex))); } + public void testInvokeDynamicWithCopy() throws Exception { + CtClass srcCc = loader.get("test4.InvokeDynCopySrc"); + CtClass destCc = loader.get("test4.InvokeDynCopyDest"); + + // copy source constructor to dest + for (CtConstructor constructor : destCc.getConstructors()) { + for (CtConstructor srcClassConstructor : srcCc.getConstructors()) { + if (constructor.getSignature().equalsIgnoreCase(srcClassConstructor.getSignature())) { + constructor.setBody(srcClassConstructor, null); + } + } + } + + // set dest class method body by source class + destCc.getDeclaredMethod("getString").setBody(srcCc.getDeclaredMethod("getString"), new ClassMap()); + + Object destObj = (new Loader(loader)).loadClass(destCc.getName()).getConstructor().newInstance(); + + // if don't copy bootstrap method and static lambda method it will throw exception when invoke + assertEquals("hello", destObj.getClass().getMethod("getString").invoke(destObj)); + } + public static Test suite() { TestSuite suite = new TestSuite("Bytecode Tests"); suite.addTestSuite(BytecodeTest.class); diff --git a/src/test/test4/InvokeDynCopyDest.java b/src/test/test4/InvokeDynCopyDest.java new file mode 100644 index 00000000..003c7c53 --- /dev/null +++ b/src/test/test4/InvokeDynCopyDest.java @@ -0,0 +1,11 @@ +package test4; + +public class InvokeDynCopyDest { + public InvokeDynCopyDest() { + System.out.println("my output:" + getString()); + } + + public String getString() { + return "dest"; + } +} diff --git a/src/test/test4/InvokeDynCopySrc.java b/src/test/test4/InvokeDynCopySrc.java new file mode 100644 index 00000000..83291f7b --- /dev/null +++ b/src/test/test4/InvokeDynCopySrc.java @@ -0,0 +1,17 @@ +package test4; + +import java.util.function.Supplier; + +public class InvokeDynCopySrc { + public InvokeDynCopySrc() { + System.out.println("source class:" + getString()); + } + + public String getString() { + Supplier stringSupplier = () -> { + return "hello"; + }; + + return stringSupplier.get(); + } +} -- 2.39.5