]> source.dussan.org Git - javassist.git/commitdiff
fixed JASSIST-160 by rewriting the whole javassist.bytecode.stackmap package.
authorchiba <chiba@30ef5769-5b8d-40dd-aea6-55b5d6557bb3>
Fri, 28 Sep 2012 17:07:43 +0000 (17:07 +0000)
committerchiba <chiba@30ef5769-5b8d-40dd-aea6-55b5d6557bb3>
Fri, 28 Sep 2012 17:07:43 +0000 (17:07 +0000)
git-svn-id: http://anonsvn.jboss.org/repos/javassist/trunk@666 30ef5769-5b8d-40dd-aea6-55b5d6557bb3

src/main/javassist/bytecode/CodeIterator.java
src/main/javassist/bytecode/StackMap.java
src/main/javassist/bytecode/StackMapTable.java
src/test/javassist/bytecode/StackMapTest.java

index 50dce6e354b33ac386b5e9cf366accaa3b249524..c25e2e8fc8d8df0ea8600dfa00705f1df89d76ae 100644 (file)
@@ -725,10 +725,10 @@ public class CodeIterator implements Opcode {
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3,
         3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 0, 0, 1, 1, 1, 1, 1, 1, 3, 3,
-        3, 3, 3, 3, 3, 5, 0, 3, 2, 3, 1, 1, 3, 3, 1, 1, 0, 4, 3, 3,
+        3, 3, 3, 3, 3, 5, 5, 3, 2, 3, 1, 1, 3, 3, 1, 1, 0, 4, 3, 3,
         5, 5
     };
-    // 0 .. UNUSED (186), LOOKUPSWITCH, TABLESWITCH, WIDE
+    // 0 .. LOOKUPSWITCH, TABLESWITCH, WIDE
 
     /**
      * Calculates the index of the next opcode.
@@ -1038,6 +1038,14 @@ public class CodeIterator implements Opcode {
             if (stack2 != null)
                 stack2.shiftPc(where, gapLength, exclusive);
         }
+
+        void shiftForSwitch(int where, int gapLength) throws BadBytecode {
+            if (stack != null)
+                stack.shiftForSwitch(where, gapLength);
+
+            if (stack2 != null)
+                stack2.shiftForSwitch(where, gapLength);
+        }
     }
 
     /*
@@ -1047,13 +1055,13 @@ public class CodeIterator implements Opcode {
                                   CodeAttribute ca, CodeAttribute.LdcEntry ldcs)
         throws BadBytecode
     {
-        ArrayList jumps = makeJumpList(code, code.length);
+        Pointers pointers = new Pointers(0, 0, 0, etable, ca);
+        ArrayList jumps = makeJumpList(code, code.length, pointers);
         while (ldcs != null) {
             addLdcW(ldcs, jumps);
             ldcs = ldcs.next;
         }
 
-        Pointers pointers = new Pointers(0, 0, 0, etable, ca);
         byte[] r = insertGap2w(code, 0, 0, false, jumps, pointers);
         return r;
     }
@@ -1091,8 +1099,8 @@ public class CodeIterator implements Opcode {
         if (gapLength <= 0)
             return code;
 
-        ArrayList jumps = makeJumpList(code, code.length);
         Pointers pointers = new Pointers(currentPos, mark, where, etable, ca);
+        ArrayList jumps = makeJumpList(code, code.length, pointers);
         byte[] r = insertGap2w(code, where, gapLength, exclusive, jumps, pointers);
         currentPos = pointers.cursor;
         mark = pointers.mark;
@@ -1152,7 +1160,7 @@ public class CodeIterator implements Opcode {
         return makeExapndedCode(code, jumps, where, gapLength);
     }
 
-    private static ArrayList makeJumpList(byte[] code, int endPos)
+    private static ArrayList makeJumpList(byte[] code, int endPos, Pointers ptrs)
         throws BadBytecode
     {
         ArrayList jumps = new ArrayList();
@@ -1191,7 +1199,7 @@ public class CodeIterator implements Opcode {
                     i0 += 4;
                 }
 
-                jumps.add(new Table(i, defaultbyte, lowbyte, highbyte, offsets));
+                jumps.add(new Table(i, defaultbyte, lowbyte, highbyte, offsets, ptrs));
             }
             else if (inst == LOOKUPSWITCH) {
                 int i2 = (i & ~3) + 4;  // 0-3 byte padding
@@ -1206,7 +1214,7 @@ public class CodeIterator implements Opcode {
                     i0 += 8;
                 }
 
-                jumps.add(new Lookup(i, defaultbyte, matches, offsets));
+                jumps.add(new Lookup(i, defaultbyte, matches, offsets, ptrs));
             }
         }
 
@@ -1300,7 +1308,7 @@ public class CodeIterator implements Opcode {
         int deltaSize() { return 0; }   // newSize - oldSize
 
         // This returns the original instruction size.
-        abstract int write(int srcPos, byte[] code, int destPos, byte[] newcode);
+        abstract int write(int srcPos, byte[] code, int destPos, byte[] newcode) throws BadBytecode;
     }
 
     /* used by changeLdcToLdcW() and CodeAttribute.LdcEntry.
@@ -1448,12 +1456,14 @@ public class CodeIterator implements Opcode {
     static abstract class Switcher extends Branch {
         int gap, defaultByte;
         int[] offsets;
+        Pointers pointers;
 
-        Switcher(int pos, int defaultByte, int[] offsets) {
+        Switcher(int pos, int defaultByte, int[] offsets, Pointers ptrs) {
             super(pos);
             this.gap = 3 - (pos & 3);
             this.defaultByte = defaultByte;
             this.offsets = offsets;
+            this.pointers = ptrs;
         }
 
         void shift(int where, int gapLength, boolean exclusive) {
@@ -1481,7 +1491,7 @@ public class CodeIterator implements Opcode {
             return gap - (3 - (orgPos & 3));
         }
 
-        int write(int src, byte[] code, int dest, byte[] newcode) {
+        int write(int src, byte[] code, int dest, byte[] newcode) throws BadBytecode {
             int padding = 3 - (pos & 3);
             int nops = gap - padding;
             int bytecodeSize = 5 + (3 - (orgPos & 3)) + tableSize();
@@ -1511,7 +1521,8 @@ public class CodeIterator implements Opcode {
          * dead code.  It complicates the generation of StackMap and
          * StackMapTable.
          */
-        void adjustOffsets(int size, int nops) {
+        void adjustOffsets(int size, int nops) throws BadBytecode {
+            pointers.shiftForSwitch(pos + size, nops);
             if (defaultByte == size)
                 defaultByte -= nops;
 
@@ -1524,8 +1535,8 @@ public class CodeIterator implements Opcode {
     static class Table extends Switcher {
         int low, high;
 
-        Table(int pos, int defaultByte, int low, int high, int[] offsets) {
-            super(pos, defaultByte, offsets);
+        Table(int pos, int defaultByte, int low, int high, int[] offsets, Pointers ptrs) {
+            super(pos, defaultByte, offsets, ptrs);
             this.low = low;
             this.high = high;
         }
@@ -1549,8 +1560,8 @@ public class CodeIterator implements Opcode {
     static class Lookup extends Switcher {
         int[] matches;
 
-        Lookup(int pos, int defaultByte, int[] matches, int[] offsets) {
-            super(pos, defaultByte, offsets);
+        Lookup(int pos, int defaultByte, int[] matches, int[] offsets, Pointers ptrs) {
+            super(pos, defaultByte, offsets, ptrs);
             this.matches = matches;
         }
 
index be54fda4e054422d7f754cfc4e79b723904db92f..fe3655fd6cad1b85e7ad31a3a0e4c29a9b8c72c0 100644 (file)
@@ -398,6 +398,32 @@ public class StackMap extends AttributeInfo {
         }
     }
 
+    /**
+     * @see CodeIterator.Switcher#adjustOffsets(int, int)
+     */
+    void shiftForSwitch(int where, int gapSize) throws BadBytecode {
+        new SwitchShifter(this, where, gapSize).visit();
+    }
+
+    static class SwitchShifter extends Walker {
+        private int where, gap;
+
+        public SwitchShifter(StackMap smt, int where, int gap) {
+            super(smt);
+            this.where = where;
+            this.gap = gap;
+        }
+
+        public int locals(int pos, int offset, int num) {
+            if (where == pos + offset)
+                ByteArray.write16bit(offset - gap, info, pos - 4);
+            else if (where == pos)
+                ByteArray.write16bit(offset + gap, info, pos - 4);
+
+            return super.locals(pos, offset, num);
+        }
+    }
+
     /**
      * Undocumented method.  Do not use; internal-use only.
      *
index 15726be281876f859651766022c66b90a5a61c78..4518ef36b231ce92715317e61b40d6cae91af9ea 100644 (file)
@@ -796,10 +796,10 @@ public class StackMapTable extends AttributeInfo {
 
     static class Shifter extends Walker {
         private StackMapTable stackMap;
-        private int where, gap;
-        private int position;
-        private byte[] updatedInfo;
-        private boolean exclusive;
+        int where, gap;
+        int position;
+        byte[] updatedInfo;
+        boolean exclusive;
 
         public Shifter(StackMapTable smt, int where, int gap, boolean exclusive) {
             super(smt);
@@ -825,7 +825,7 @@ public class StackMapTable extends AttributeInfo {
             update(pos, offsetDelta, 64, 247);
         }
 
-        private void update(int pos, int offsetDelta, int base, int entry) {
+        void update(int pos, int offsetDelta, int base, int entry) {
             int oldPos = position;
             position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1);
             boolean match;
@@ -850,7 +850,7 @@ public class StackMapTable extends AttributeInfo {
             }
         }
 
-        private static byte[] insertGap(byte[] info, int where, int gap) {
+        static byte[] insertGap(byte[] info, int where, int gap) {
             int len = info.length;
             byte[] newinfo = new byte[len + gap];
             for (int i = 0; i < len; i++)
@@ -872,7 +872,7 @@ public class StackMapTable extends AttributeInfo {
             update(pos, offsetDelta);
         }
 
-        private void update(int pos, int offsetDelta) {
+        void update(int pos, int offsetDelta) {
             int oldPos = position;
             position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1);
             boolean match;
@@ -889,6 +889,73 @@ public class StackMapTable extends AttributeInfo {
         }
     }
 
+    /**
+     * @see CodeIterator.Switcher#adjustOffsets(int, int)
+     */
+    void shiftForSwitch(int where, int gapSize) throws BadBytecode {
+        new SwitchShifter(this, where, gapSize).doit();
+    }
+
+    static class SwitchShifter extends Shifter {
+        SwitchShifter(StackMapTable smt, int where, int gap) {
+            super(smt, where, gap, false);
+        }
+
+        void update(int pos, int offsetDelta, int base, int entry) {
+            int oldPos = position;
+            position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1);
+            int newDelta = offsetDelta;
+            if (where == position)
+                newDelta = offsetDelta - gap;
+            else if (where == oldPos)
+                newDelta = offsetDelta + gap;
+            else
+                return;
+
+            if (offsetDelta < 64)
+                if (newDelta < 64)
+                    info[pos] = (byte)(newDelta + base);
+                else {
+                    byte[] newinfo = insertGap(info, pos, 2);
+                    newinfo[pos] = (byte)entry;
+                    ByteArray.write16bit(newDelta, newinfo, pos + 1);
+                    updatedInfo = newinfo;
+                }
+            else
+                if (newDelta < 64) {
+                    byte[] newinfo = deleteGap(info, pos, 2);
+                    newinfo[pos] = (byte)(newDelta + base);
+                    updatedInfo = newinfo;
+                }
+                else
+                    ByteArray.write16bit(newDelta, info, pos + 1);
+        }
+
+        static byte[] deleteGap(byte[] info, int where, int gap) {
+            where += gap;
+            int len = info.length;
+            byte[] newinfo = new byte[len - gap];
+            for (int i = 0; i < len; i++)
+                newinfo[i - (i < where ? 0 : gap)] = info[i];
+
+            return newinfo;
+        }
+
+        void update(int pos, int offsetDelta) {
+            int oldPos = position;
+            position = oldPos + offsetDelta + (oldPos == 0 ? 0 : 1);
+            int newDelta = offsetDelta;
+            if (where == position)
+                newDelta = offsetDelta - gap;
+            else if (where == oldPos)
+                newDelta = offsetDelta + gap;
+            else
+                return;
+
+            ByteArray.write16bit(newDelta, info, pos + 1);
+        }
+    }
+
     /**
      * Undocumented method.  Do not use; internal-use only.
      *
index 3dba27d36f31283f089f2d2b843fa5b2375f135d..c7a77cb8d3b3efe60d65f0713bfeb08cf724d5ad 100644 (file)
@@ -10,6 +10,7 @@ import java.io.PrintStream;
 import java.lang.reflect.Method;
 
 import javassist.ClassPool;
+import javassist.CodeConverter;
 import javassist.CtClass;
 import javassist.CtMethod;
 import javassist.CtNewMethod;
@@ -299,7 +300,7 @@ public class StackMapTest extends TestCase {
         rebuildStackMaps2(cc);
         cc.writeFile();
         Object t1 = make(cc.getName());
-        assertEquals(122, invoke(t1, "test"));
+        assertEquals(123, invoke(t1, "test"));
     }
 
     public static class T5 {
@@ -339,6 +340,89 @@ public class StackMapTest extends TestCase {
         }
     }
 
+    public void testSwitchCase() throws Exception {
+        CtClass cc = loader.get("javassist.bytecode.StackMapTest$T7");
+        // CodeConverter conv = new CodeConverter();
+        // conv.replaceNew(cc, cc, "make2");
+        // cc.instrument(conv);
+        StringBuffer sbuf = new StringBuffer("String s;");
+        for (int i = 0; i < 130; i++)
+            sbuf.append("s =\"" + i + "\";");
+
+        cc.getDeclaredMethod("foo").insertBefore(sbuf.toString());
+        cc.getDeclaredMethod("test2").setBody(loader.get("javassist.bytecode.StackMapTest$T8").getDeclaredMethod("test2"), null);
+        //rebuildStackMaps2(cc);
+        cc.writeFile();
+        Object t1 = make(cc.getName());
+        assertEquals(110, invoke(t1, "test"));
+    }
+
+    public static class T7 {
+        int value = 1;
+        T7 t7;
+        public static T7 make2() { return null; }
+        public int foo() { return 1; }
+        public int test() { return test2(10); }
+        public int test2(int k) { return k; }
+    }
+
+    public static class T8 {
+        public int test2(int k) {
+            String s = "abc";
+            T7 t = k > 0 ? new T7() : new T7();
+            switch (k) {
+            case 0:
+                t = new T7();
+                k += t.value;
+                break;
+            case 10:
+                k += 100;
+                break;
+            }
+            return k;
+        }
+    }
+
+    public void testSwitchCase2() throws Exception {
+        CtClass cc = loader.get("javassist.bytecode.StackMapTest$T7b");
+        StringBuffer sbuf = new StringBuffer("String s;");
+        for (int i = 0; i < 130; i++)
+            sbuf.append("s =\"" + i + "\";");
+
+        cc.getDeclaredMethod("foo").insertBefore(sbuf.toString());
+        cc.getDeclaredMethod("test2").setBody(loader.get("javassist.bytecode.StackMapTest$T8b").getDeclaredMethod("test2"), null);
+        rebuildStackMaps2(cc);
+        cc.writeFile();
+        Object t1 = make(cc.getName());
+        assertEquals(110, invoke(t1, "test"));
+    }
+
+    public static class T7b {
+        int value = 1;
+        T7b t7;
+        public static T7b make2() { return null; }
+        public int foo() { return 1; }
+        public int test() { return test2(10); }
+        public int test2(int k) { return k; }
+    }
+
+    public static class T8b {
+        public int test2(int k) {
+            String s = "abc";
+            T7b t = k > 0 ? new T7b() : new T7b();
+            switch (k) {
+            case 0:
+                t = new T7b();
+                k += t.value;
+                break;
+            case 10:
+                k += 100;
+                break;
+            }
+            return k;
+        }
+    }
+
     public void tstCtClassType() throws Exception {
         ClassPool cp = ClassPool.getDefault();
         CtClass cc = cp.get("javassist.CtClassType");