aboutsummaryrefslogtreecommitdiffstats
path: root/src/test/test3/Enhancer.java
blob: 5c9d2a258f615e0438314c284efa9d0a1aac5df6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package test3;

import javassist.*;
import java.lang.reflect.Method;
import java.lang.reflect.Field;

/* Test code: a minimal class that Enhancer.main() uses as the
 * superclass of the generated proxy.
 */
class EnhanceTest {
    /**
     * Explicit public no-argument constructor. The proxy class that
     * Enhancer generates extends this class, so its construction chains
     * to this constructor.
     */
    public EnhanceTest() {
        super();
    }

    /**
     * Prints {@code s} to standard output, followed by a line separator.
     *
     * @param s the text to print
     */
    public void foo(String s) {
        System.out.println(s);
    }
}

/**
 * Generates a subclass ("proxy") of a given class in which overridable
 * instance methods are routed through an {@link Interceptor}.
 *
 * <p>For every overridden method {@code foo} the proxy receives:
 * <ul>
 *   <li>a delegator method with a unique name that forwards to the
 *       superclass implementation of {@code foo};</li>
 *   <li>a private {@code java.lang.reflect.Method} field that lazily caches
 *       that delegator (looked up via {@link #findMethod});</li>
 *   <li>a replacement {@code foo} that passes {@code this}, the cached
 *       delegator and the call arguments to the proxy's public static
 *       {@code interceptor} field.</li>
 * </ul>
 * The callback must be installed with {@link #setCallback} before any
 * overridden method is called on a proxy instance; otherwise the generated
 * code dereferences a null {@code interceptor} field.
 */
public class Enhancer {
    private ClassPool pool;
    private CtClass superClass;
    private CtClass thisClass;          // the proxy class being generated
    private Class thisJavaClass;        // loaded proxy class, null until createClass()
    private Interceptor interceptor;
    private int unique;                 // counter for the cached-Method field names

    /** Name of the public static field in the proxy that holds the callback. */
    private static final String INTERCEPTOR = "interceptor";

    /* Test method: proxies EnhanceTest, installs an interceptor that logs
     * each intercepted call before forwarding it, then calls foo() on a
     * proxy instance.
     */
    public static void main(String[] args) throws Exception {
        Enhancer e = new Enhancer(test3.EnhanceTest.class);
        e.overrideAll();
        e.setCallback(new Interceptor() {
                public Object invoke(Object self, Method m, Object[] args)
                    throws Exception
                {
                    System.out.println("intercept: " + m);
                    return m.invoke(self, args);
                }
            });
        Class c = e.createClass();
        EnhanceTest obj = (EnhanceTest)c.newInstance();
        obj.foo("test");
    }

    /** Callback invoked by every overridden method of a generated proxy. */
    public static interface Interceptor {
        /**
         * Handles one intercepted call.
         *
         * @param self the proxy instance the call was made on
         * @param m    the delegator method; invoking it on {@code self}
         *             with {@code args} runs the superclass implementation
         * @param args the arguments of the intercepted call
         * @return the value to return to the caller; the generated code casts
         *         it to the overridden method's return type
         * @throws Exception propagated to the caller of the overridden method
         */
        Object invoke(Object self, Method m, Object[] args) throws Exception;
    }

    /**
     * Creates an enhancer for a loaded class. Its bytecode is looked up
     * through a fresh ClassPool that searches the system path and the
     * class path of {@code clazz} itself.
     *
     * @param clazz the class to extend
     */
    public Enhancer(Class clazz)
        throws CannotCompileException, NotFoundException
    {
        this(makeClassPool(clazz).get(clazz.getName()));
    }

    /* Builds a ClassPool that can locate the class file of clazz. */
    private static ClassPool makeClassPool(Class clazz) {
        ClassPool cp = new ClassPool();
        cp.appendSystemPath();
        cp.insertClassPath(new ClassClassPath(clazz));
        return cp;
    }

    /**
     * Creates an enhancer for {@code superClass}. Declares a new class named
     * {@code superClass.getName() + "_proxy"} that extends it and has a
     * public static {@link Interceptor} field named {@code interceptor}.
     *
     * @param superClass the class to extend
     */
    public Enhancer(CtClass superClass)
        throws CannotCompileException, NotFoundException
    {
        this.pool = superClass.getClassPool();
        this.superClass = superClass;
        String name = superClass.getName() + "_proxy";
        thisClass = pool.makeClass(name);
        thisClass.setSuperclass(superClass);
        String src =
            "public static " + this.getClass().getName()
          + ".Interceptor " + INTERCEPTOR + ";";

        thisClass.addField(CtField.make(src, thisClass));
        this.thisJavaClass = null;
        unique = 0;
    }

    /**
     * Overrides every method reported by {@code superClass.getMethods()}
     * that is not final, abstract or static. Each override gets a delegator
     * named with a shared unique prefix followed by the method's index.
     * Javassist's getMethods() also returns inherited non-private methods,
     * including protected ones such as {@code Object.clone}.
     */
    public void overrideAll()
        throws CannotCompileException, NotFoundException
    {
        CtMethod[] methods = superClass.getMethods();
        String delegatorNamePrefix = thisClass.makeUniqueName("d");
        for (int i = 0; i < methods.length; i++) {
            CtMethod m = methods[i];
            int mod = m.getModifiers();
            if (!Modifier.isFinal(mod) && !Modifier.isAbstract(mod)
                && !Modifier.isStatic(mod))
                override(m, delegatorNamePrefix + i);
        }
    }

    /**
     * Overrides {@code m} in the proxy class so that calls go through the
     * interceptor.
     *
     * <p>Adds a private Method field, a delegator named
     * {@code delegatorName} that forwards to the superclass implementation,
     * and the intercepting override itself.
     *
     * @param m             the superclass method to override
     * @param delegatorName the name of the generated delegator method
     */
    public void override(CtMethod m, String delegatorName)
        throws CannotCompileException, NotFoundException
    {
        String fieldName = "m" + unique++;
        thisClass.addField(
            CtField.make("private java.lang.reflect.Method "
                         + fieldName + ";", thisClass));
        CtMethod delegator = CtNewMethod.delegator(m, thisClass);
        // The delegator carries m's modifiers; a native method must not
        // stay native once it has a body.
        delegator.setModifiers(Modifier.clear(delegator.getModifiers(),
                                              Modifier.NATIVE));
        delegator.setName(delegatorName);
        thisClass.addMethod(delegator);
        thisClass.addMethod(makeMethod(m, fieldName, delegatorName));
    }

    /* Builds the intercepting override of m. On first call it caches the
     * delegator Method in fieldName (via findMethod), then hands the call
     * to the interceptor and casts the result to m's return type ($r).
     */
    private CtMethod makeMethod(CtMethod m, String fieldName,
                                String delegatorName)
        throws CannotCompileException, NotFoundException
    {
        String factory = this.getClass().getName() + ".findMethod(this, \"" +
            delegatorName + "\");";
        String body
            = "{ if (" + fieldName + " == null) " +
                   fieldName + " = " + factory +
                 "return ($r)" + INTERCEPTOR + ".invoke(this, " + fieldName +
                 ", $args); }";
        CtMethod m2 = CtNewMethod.make(m.getReturnType(),
                                       m.getName(),
                                       m.getParameterTypes(),
                                       m.getExceptionTypes(),
                                       body, thisClass);
        m2.setModifiers(Modifier.clear(m.getModifiers(),
                                       Modifier.NATIVE));
        return m2;
    }

    /**
     * Runtime support routine called by an enhanced object. Returns the
     * method named {@code name} declared by the class of {@code self} or by
     * one of its superclasses, with the most-derived class searched first.
     *
     * <p>Delegators keep the modifiers of the method they forward to (see
     * {@link #override}), so the delegator for a protected method is itself
     * protected. {@code Class.getMethods()} reports only public methods and
     * would miss such a delegator. For that reason declared methods are
     * searched here, and the result is made accessible so the interceptor
     * can invoke it reflectively.
     *
     * @param self the enhanced object
     * @param name the delegator name to look up
     * @return the matching method, already made accessible
     * @throws RuntimeException if no method with that name exists
     */
    public static Method findMethod(Object self, String name) {
        for (Class c = self.getClass(); c != null; c = c.getSuperclass()) {
            Method[] methods = c.getDeclaredMethods();
            for (int i = 0; i < methods.length; i++) {
                if (methods[i].getName().equals(name)) {
                    Method found = methods[i];
                    found.setAccessible(true);
                    return found;
                }
            }
        }

        throw new RuntimeException("not found " + name
                                   + " in " + self.getClass());
    }

    /**
     * Loads the generated proxy class on the first call and returns it.
     * Later calls return the same Class object.
     *
     * <p>On the first call this also writes the generated class file via
     * Javassist's {@code debugWriteFile()}. It then installs the callback,
     * if {@link #setCallback} was already called.
     *
     * @return the loaded proxy class
     * @throws RuntimeException wrapping a CannotCompileException
     */
    public Class createClass() {
        if (thisJavaClass == null) {
            try {
                thisClass.debugWriteFile();
                thisJavaClass = thisClass.toClass();
                setInterceptor();
            }
            catch (CannotCompileException e) {
                throw new RuntimeException(e);
            }
        }

        return thisJavaClass;
    }

    /* Writes cc's class file while leaving cc modifiable afterwards:
     * pruning is suspended for the write and the class is defrosted after.
     * Not referenced elsewhere in this class (createClass uses
     * debugWriteFile instead).
     */
    private static void writeFile(CtClass cc) {
        try {
            cc.stopPruning(true);
            cc.writeFile();
            cc.defrost();
            cc.stopPruning(false);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets the callback used by all proxy instances. If the proxy class is
     * already loaded, the callback is stored into its static field
     * immediately; otherwise that happens in {@link #createClass}.
     *
     * @param mi the interceptor to install
     */
    public void setCallback(Interceptor mi) {
        interceptor = mi;
        setInterceptor();
    }

    /* Copies the callback into the proxy's public static interceptor field
     * once both the callback and the loaded class are available.
     */
    private void setInterceptor() {
        if (thisJavaClass != null && interceptor != null) {
            try {
                Field f = thisJavaClass.getField(INTERCEPTOR);
                f.set(null, interceptor);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }
}