From d58f611274b91ff23e521db513107189b0991ca0 Mon Sep 17 00:00:00 2001 From: chiba Date: Fri, 28 Sep 2012 17:07:43 +0000 Subject: [PATCH] fixed JASSIST-160 by rewriting the whole javassist.bytecode.stackmap package. git-svn-id: http://anonsvn.jboss.org/repos/javassist/trunk@666 30ef5769-5b8d-40dd-aea6-55b5d6557bb3 --- src/main/javassist/bytecode/CodeIterator.java | 43 ++++++---- src/main/javassist/bytecode/StackMap.java | 26 ++++++ .../javassist/bytecode/StackMapTable.java | 81 +++++++++++++++-- src/test/javassist/bytecode/StackMapTest.java | 86 ++++++++++++++++++- 4 files changed, 212 insertions(+), 24 deletions(-) diff --git a/src/main/javassist/bytecode/CodeIterator.java b/src/main/javassist/bytecode/CodeIterator.java index 50dce6e3..c25e2e8f 100644 --- a/src/main/javassist/bytecode/CodeIterator.java +++ b/src/main/javassist/bytecode/CodeIterator.java @@ -725,10 +725,10 @@ public class CodeIterator implements Opcode { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 0, 0, 1, 1, 1, 1, 1, 1, 3, 3, - 3, 3, 3, 3, 3, 5, 0, 3, 2, 3, 1, 1, 3, 3, 1, 1, 0, 4, 3, 3, + 3, 3, 3, 3, 3, 5, 5, 3, 2, 3, 1, 1, 3, 3, 1, 1, 0, 4, 3, 3, 5, 5 }; - // 0 .. UNUSED (186), LOOKUPSWITCH, TABLESWITCH, WIDE + // 0 .. LOOKUPSWITCH, TABLESWITCH, WIDE /** * Calculates the index of the next opcode. @@ -1038,6 +1038,14 @@ public class CodeIterator implements Opcode { if (stack2 != null) stack2.shiftPc(where, gapLength, exclusive); } + + void shiftForSwitch(int where, int gapLength) throws BadBytecode { + if (stack != null) + stack.shiftForSwitch(where, gapLength); + + if (stack2 != null) + stack2.shiftForSwitch(where, gapLength); + } } /* @@ -1047,13 +1055,13 @@ public class CodeIterator implements Opcode { CodeAttribute ca, CodeAttribute.LdcEntry ldcs) throws BadBytecode { - ArrayList jumps = makeJumpList(code, code.length); + Pointers pointers = new Pointers(0, 0, 0, etable, ca); + ArrayList jumps = makeJumpList(code, code.length, pointers); while (ldcs != null) { addLdcW(ldcs, jumps); ldcs = ldcs.next; } - Pointers pointers = new Pointers(0, 0, 0, etable, ca); byte[] r = insertGap2w(code, 0, 0, false, jumps, pointers); return r; } @@ -1091,8 +1099,8 @@ public class CodeIterator implements Opcode { if (gapLength <= 0) return code; - ArrayList jumps = makeJumpList(code, code.length); Pointers pointers = new Pointers(currentPos, mark, where, etable, ca); + ArrayList jumps = makeJumpList(code, code.length, pointers); byte[] r = insertGap2w(code, where, gapLength, exclusive, jumps, pointers); currentPos = pointers.cursor; mark = pointers.mark; @@ -1152,7 +1160,7 @@ public class CodeIterator implements Opcode { return makeExapndedCode(code, jumps, where, gapLength); } - private static ArrayList makeJumpList(byte[] code, int endPos) + private static ArrayList makeJumpList(byte[] code, int endPos, Pointers ptrs) throws BadBytecode { ArrayList jumps = new ArrayList(); @@ -1191,7 +1199,7 @@ public class CodeIterator implements Opcode { i0 += 4; } - jumps.add(new Table(i, defaultbyte, lowbyte, highbyte, offsets)); + jumps.add(new Table(i, defaultbyte, lowbyte, highbyte, offsets, ptrs)); } else if (inst == LOOKUPSWITCH) { int i2 = (i & ~3) + 4; // 0-3 byte padding @@ -1206,7 +1214,7 @@ public class CodeIterator implements Opcode { i0 += 8; } - jumps.add(new Lookup(i, defaultbyte, matches, offsets)); + jumps.add(new Lookup(i, defaultbyte, matches, offsets, ptrs)); } } @@ -1300,7 +1308,7 @@ public class CodeIterator implements Opcode { int deltaSize() { return 0; } // newSize - oldSize // This returns the original instruction size. - abstract int write(int srcPos, byte[] code, int destPos, byte[] newcode); + abstract int write(int srcPos, byte[] code, int destPos, byte[] newcode) throws BadBytecode; } /* used by changeLdcToLdcW() and CodeAttribute.LdcEntry. @@ -1448,12 +1456,14 @@ public class CodeIterator implements Opcode { static abstract class Switcher extends Branch { int gap, defaultByte; int[] offsets; + Pointers pointers; - Switcher(int pos, int defaultByte, int[] offsets) { + Switcher(int pos, int defaultByte, int[] offsets, Pointers ptrs) { super(pos); this.gap = 3 - (pos & 3); this.defaultByte = defaultByte; this.offsets = offsets; + this.pointers = ptrs; } void shift(int where, int gapLength, boolean exclusive) { @@ -1481,7 +1491,7 @@ public class CodeIterator implements Opcode { return gap - (3 - (orgPos & 3)); } - int write(int src, byte[] code, int dest, byte[] newcode) { + int write(int src, byte[] code, int dest, byte[] newcode) throws BadBytecode { int padding = 3 - (pos & 3); int nops = gap - padding; int bytecodeSize = 5 + (3 - (orgPos & 3)) + tableSize(); @@ -1511,7 +1521,8 @@ public class CodeIterator implements Opcode { * dead code. It complicates the generation of StackMap and * StackMapTable. */ - void adjustOffsets(int size, int nops) { + void adjustOffsets(int size, int nops) throws BadBytecode { + pointers.shiftForSwitch(pos + size, nops); if (defaultByte == size) defaultByte -= nops; @@ -1524,8 +1535,8 @@ public class CodeIterator implements Opcode { static class Table extends Switcher { int low, high; - Table(int pos, int defaultByte, int low, int high, int[] offsets) { - super(pos, defaultByte, offsets); + Table(int pos, int defaultByte, int low, int high, int[] offsets, Pointers ptrs) { + super(pos, defaultByte, offsets, ptrs); this.low = low; this.high = high; } @@ -1549,8 +1560,8 @@ public class CodeIterator implements Opcode { static class Lookup extends Switcher { int[] matches; - Lookup(int pos, int defaultByte, int[] matches, int[] offsets) { - super(pos, defaultByte, offsets); + Lookup(int pos, int defaultByte, int[] matches, int[] offsets, Pointers ptrs) { + super(pos, defaultByte, offsets, ptrs); this.matches = matches; } diff --git a/src/main/javassist/bytecode/StackMap.java b/src/main/javassist/bytecode/StackMap.java index be54fda4..fe3655fd 100644 --- a/src/main/javassist/bytecode/StackMap.java +++ b/src/main/javassist/bytecode/StackMap.java @@ -398,6 +398,32 @@ public class StackMap extends AttributeInfo { } } + /** + * @see CodeIterator.Switcher#adjustOffsets(int, int) + */ + void shiftForSwitch(int where, int gapSize) throws BadBytecode { + new SwitchShifter(this, where, gapSize).visit(); + } + + static class SwitchShifter extends Walker { + private int where, gap; + + public SwitchShifter(StackMap smt, int where, int gap) { + super(smt); + this.where = where; + this.gap = gap; + } + + public int locals(int pos, int offset, int num) { + if (where == pos + offset) + ByteArray.write16bit(offset - gap, info, pos - 4); + else if (where == pos) + ByteArray.write16bit(offset + gap, info, pos - 4); + + return super.locals(pos, offset, num); + } + } + /** * Undocumented method. Do not use; internal-use only. * diff --git a/src/main/javassist/bytecode/StackMapTable.java b/src/main/javassist/bytecode/StackMapTable.java index 15726be2..4518ef36 100644 --- a/src/main/javassist/bytecode/StackMapTable.java +++ b/src/main/javassist/bytecode/StackMapTable.java @@ -796,10 +796,10 @@ public class StackMapTable extends AttributeInfo { static class Shifter extends Walker { private StackMapTable stackMap; - private int where, gap; - private int position; - private byte[] updatedInfo; - private boolean exclusive; + int where, gap; + int position; + byte[] updatedInfo; + boolean exclusive; public Shifter(StackMapTable smt, int where, int gap, boolean exclusive) { super(smt); @@ -825,7 +825,7 @@ public class StackMapTable extends AttributeInfo { update(pos, offsetDelta, 64, 247); } - private void update(int pos, int offsetDelta, int base, int entry) { + void update(int pos, int offsetDelta, int base, int entry) { int oldPos = position; position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1); boolean match; @@ -850,7 +850,7 @@ public class StackMapTable extends AttributeInfo { } } - private static byte[] insertGap(byte[] info, int where, int gap) { + static byte[] insertGap(byte[] info, int where, int gap) { int len = info.length; byte[] newinfo = new byte[len + gap]; for (int i = 0; i < len; i++) @@ -872,7 +872,7 @@ public class StackMapTable extends AttributeInfo { update(pos, offsetDelta); } - private void update(int pos, int offsetDelta) { + void update(int pos, int offsetDelta) { int oldPos = position; position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1); boolean match; @@ -889,6 +889,73 @@ public class StackMapTable extends AttributeInfo { } } + /** + * @see CodeIterator.Switcher#adjustOffsets(int, int) + */ + void shiftForSwitch(int where, int gapSize) throws BadBytecode { + new SwitchShifter(this, where, gapSize).doit(); + } + + static class SwitchShifter extends Shifter { + SwitchShifter(StackMapTable smt, int where, int gap) { + super(smt, where, gap, false); + } + + void update(int pos, int offsetDelta, int base, int entry) { + int oldPos = position; + position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1); + int newDelta = offsetDelta; + if (where == position) + newDelta = offsetDelta - gap; + else if (where == oldPos) + newDelta = offsetDelta + gap; + else + return; + + if (offsetDelta < 64) + if (newDelta < 64) + info[pos] = (byte)(newDelta + base); + else { + byte[] newinfo = insertGap(info, pos, 2); + newinfo[pos] = (byte)entry; + ByteArray.write16bit(newDelta, newinfo, pos + 1); + updatedInfo = newinfo; + } + else + if (newDelta < 64) { + byte[] newinfo = deleteGap(info, pos, 2); + newinfo[pos] = (byte)(newDelta + base); + updatedInfo = newinfo; + } + else + ByteArray.write16bit(newDelta, info, pos + 1); + } + + static byte[] deleteGap(byte[] info, int where, int gap) { + where += gap; + int len = info.length; + byte[] newinfo = new byte[len - gap]; + for (int i = 0; i < len; i++) + newinfo[i - (i < where ? 0 : gap)] = info[i]; + + return newinfo; + } + + void update(int pos, int offsetDelta) { + int oldPos = position; + position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1); + int newDelta = offsetDelta; + if (where == position) + newDelta = offsetDelta - gap; + else if (where == oldPos) + newDelta = offsetDelta + gap; + else + return; + + ByteArray.write16bit(newDelta, info, pos + 1); + } + } + /** * Undocumented method. Do not use; internal-use only. * diff --git a/src/test/javassist/bytecode/StackMapTest.java b/src/test/javassist/bytecode/StackMapTest.java index 3dba27d3..c7a77cb8 100644 --- a/src/test/javassist/bytecode/StackMapTest.java +++ b/src/test/javassist/bytecode/StackMapTest.java @@ -10,6 +10,7 @@ import java.io.PrintStream; import java.lang.reflect.Method; import javassist.ClassPool; +import javassist.CodeConverter; import javassist.CtClass; import javassist.CtMethod; import javassist.CtNewMethod; @@ -299,7 +300,7 @@ public class StackMapTest extends TestCase { rebuildStackMaps2(cc); cc.writeFile(); Object t1 = make(cc.getName()); - assertEquals(122, invoke(t1, "test")); + assertEquals(123, invoke(t1, "test")); } public static class T5 { @@ -339,6 +340,89 @@ public class StackMapTest extends TestCase { } } + public void testSwitchCase() throws Exception { + CtClass cc = loader.get("javassist.bytecode.StackMapTest$T7"); + // CodeConverter conv = new CodeConverter(); + // conv.replaceNew(cc, cc, "make2"); + // cc.instrument(conv); + StringBuffer sbuf = new StringBuffer("String s;"); + for (int i = 0; i < 130; i++) + sbuf.append("s =\"" + i + "\";"); + + cc.getDeclaredMethod("foo").insertBefore(sbuf.toString()); + cc.getDeclaredMethod("test2").setBody(loader.get("javassist.bytecode.StackMapTest$T8").getDeclaredMethod("test2"), null); + //rebuildStackMaps2(cc); + cc.writeFile(); + Object t1 = make(cc.getName()); + assertEquals(110, invoke(t1, "test")); + } + + public static class T7 { + int value = 1; + T7 t7; + public static T7 make2() { return null; } + public int foo() { return 1; } + public int test() { return test2(10); } + public int test2(int k) { return k; } + } + + public static class T8 { + public int test2(int k) { + String s = "abc"; + T7 t = k > 0 ? new T7() : new T7(); + switch (k) { + case 0: + t = new T7(); + k += t.value; + break; + case 10: + k += 100; + break; + } + return k; + } + } + + public void testSwitchCase2() throws Exception { + CtClass cc = loader.get("javassist.bytecode.StackMapTest$T7b"); + StringBuffer sbuf = new StringBuffer("String s;"); + for (int i = 0; i < 130; i++) + sbuf.append("s =\"" + i + "\";"); + + cc.getDeclaredMethod("foo").insertBefore(sbuf.toString()); + cc.getDeclaredMethod("test2").setBody(loader.get("javassist.bytecode.StackMapTest$T8b").getDeclaredMethod("test2"), null); + rebuildStackMaps2(cc); + cc.writeFile(); + Object t1 = make(cc.getName()); + assertEquals(110, invoke(t1, "test")); + } + + public static class T7b { + int value = 1; + T7b t7; + public static T7b make2() { return null; } + public int foo() { return 1; } + public int test() { return test2(10); } + public int test2(int k) { return k; } + } + + public static class T8b { + public int test2(int k) { + String s = "abc"; + T7b t = k > 0 ? new T7b() : new T7b(); + switch (k) { + case 0: + t = new T7b(); + k += t.value; + break; + case 10: + k += 100; + break; + } + return k; + } + } + public void tstCtClassType() throws Exception { ClassPool cp = ClassPool.getDefault(); CtClass cc = cp.get("javassist.CtClassType"); -- 2.39.5