]> source.dussan.org Git - javassist.git/commitdiff
enables switch-case with string constants
authorchibash <chiba@javassist.org>
Sun, 9 Dec 2018 15:24:19 +0000 (00:24 +0900)
committerchibash <chiba@javassist.org>
Sun, 9 Dec 2018 15:24:19 +0000 (00:24 +0900)
javassist.jar
src/main/javassist/compiler/CodeGen.java
src/test/javassist/JvstTest5.java

index ce63f9d24a56c6eb97b7c6ee0c6148f02ae4c165..89913dc5cd5519b7a1aa676c4ba722c736fb2fa8 100644 (file)
Binary files a/javassist.jar and b/javassist.jar differ
index d4c748f81efe5b354eac81aebf72b121449160cd..3cb117197b1e3f6f0e6891e3abf1943a1f912982 100644 (file)
@@ -542,7 +542,23 @@ public abstract class CodeGen extends Visitor implements Opcode, TokenId {
     }
 
     private void atSwitchStmnt(Stmnt st) throws CompileError {
+        boolean isString = false;
+        if (typeChecker != null) {
+            doTypeCheck(st.head());
+            isString = typeChecker.exprType == TypeChecker.CLASS
+                       && typeChecker.arrayDim == 0
+                       && TypeChecker.jvmJavaLangString.equals(typeChecker.className);
+        }
+
         compileExpr(st.head());
+        int tmpVar = -1;
+        if (isString) {
+            tmpVar = getMaxLocals();
+            incMaxLocals(1);
+            bytecode.addAstore(tmpVar);
+            bytecode.addAload(tmpVar);
+            bytecode.addInvokevirtual(TypeChecker.jvmJavaLangString, "hashCode", "()I");
+        }
 
         List<Integer>  prevBreakList = breakList;
         breakList = new ArrayList<Integer>();
@@ -565,6 +581,7 @@ public abstract class CodeGen extends Visitor implements Opcode, TokenId {
         bytecode.addGap(npairs * 8);
 
         long[] pairs = new long[npairs];
+        ArrayList<Integer> gotoDefaults = new ArrayList<Integer>();
         int ipairs = 0;
         int defaultPc = -1;
         for (ASTList list = body; list != null; list = list.tail()) {
@@ -575,9 +592,18 @@ public abstract class CodeGen extends Visitor implements Opcode, TokenId {
             else if (op != CASE)
                 fatal();
             else {
+                int curPos = bytecode.currentPc();
+                long caseLabel;
+                if (isString) {
+                    // computeStringLabel() also adds bytecode as its side-effects.
+                    caseLabel = (long)computeStringLabel(label.head(), tmpVar, gotoDefaults);
+                }
+                else
+                    caseLabel = (long)computeLabel(label.head());
+
                 pairs[ipairs++]
-                    = ((long)computeLabel(label.head()) << 32) + 
-                      ((long)(bytecode.currentPc() - opcodePc) & 0xffffffff);
+                    = (caseLabel << 32) + 
+                      ((long)(curPos - opcodePc) & 0xffffffff);
             }
 
             hasReturned = false;
@@ -600,6 +626,8 @@ public abstract class CodeGen extends Visitor implements Opcode, TokenId {
             defaultPc = endPc;
 
         bytecode.write32bit(opcodePc2, defaultPc - opcodePc);
+        for (int addr: gotoDefaults)
+            bytecode.write16bit(addr, defaultPc - addr + 1);
 
         patchGoto(breakList, endPc);
         breakList = prevBreakList;
@@ -613,6 +641,26 @@ public abstract class CodeGen extends Visitor implements Opcode, TokenId {
         throw new CompileError("bad case label");
     }
 
+    private int computeStringLabel(ASTree expr, int tmpVar, List<Integer> gotoDefaults)
+        throws CompileError
+    {
+        doTypeCheck(expr);
+        expr = TypeChecker.stripPlusExpr(expr);
+        if (expr instanceof StringL) {
+            String label = ((StringL)expr).get();
+            bytecode.addAload(tmpVar);
+            bytecode.addLdc(label);
+            bytecode.addInvokevirtual(TypeChecker.jvmJavaLangString, "equals",
+                                      "(Ljava/lang/Object;)Z");
+            bytecode.addOpcode(IFEQ);
+            Integer pc = Integer.valueOf(bytecode.currentPc());
+            bytecode.addIndex(0);
+            gotoDefaults.add(pc);
+            return (int)label.hashCode();
+        }
+        throw new CompileError("bad case label");
+    }
+
     private void atBreakStmnt(Stmnt st, boolean notCont)
         throws CompileError
     {
index c5eff4d1ba76fecac2a9f3d6c5bc7727afa4f301..1a9bd66652c18b1fa74179db358158c8e785954b 100644 (file)
@@ -453,4 +453,35 @@ public class JvstTest5 extends JvstTestRoot {
         cc.getClassFile().compact();
         cc.toClass(test5.DefineClassCapability.class);
     }
+
+    public void testSwitchCaseWithStringConstant() throws Exception {
+        CtClass cc = sloader.get("test5.SwitchCase");
+        cc.addMethod(CtNewMethod.make(
+                "public int run() {" +
+                "    String s = \"foobar\";\n" +
+                "    switch (s) {\n" +
+                "    case STR1: return 1;\n" +
+                "    case \"foobar\": return 2;\n" +
+                "    default: return 3; }\n" +
+                "}\n", cc));
+        cc.writeFile();
+        Object obj = make(cc.getName());
+        assertEquals(2, invoke(obj, "run"));   
+    }
+
+    public void testSwitchCaseWithStringConstant2() throws Exception {
+        CtClass cc = sloader.makeClass("test5.SwitchCase2");
+        cc.addMethod(CtNewMethod.make(
+                "public int run() {" +
+                "    String s = \"foo\";\n" +
+                "    switch (s) {\n" +
+                "    case test5.SwitchCase.STR1: return 1;\n" +
+                "    case \"foobar\": return 2;\n" +
+                "    }\n" +
+                "    return 3;\n" +
+                "}\n", cc));
+        cc.writeFile();
+        Object obj = make(cc.getName());
+        assertEquals(1, invoke(obj, "run"));   
+    }
 }