/*
 * Copyright 2004-2014 the Seasar Foundation and the Others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */

package org.seasar.framework.container.hotdeploy;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.net.URL;

import org.apache.logging.log4j.LogManager;

/**
 * HOT deploy用の {@link ClassLoader}です。
 *
 * @author higa
 */
public abstract class HotdeployClassLoader extends ClassLoader {

    /** FIND_LOADED_CLASS_METHOD */
    private static final Method FIND_LOADED_CLASS_METHOD = getFindLoadedClassMethod();

    /**
     * {@link HotdeployClassLoader}を作成します。
     *
     * @param classLoader ClassLoader
     */
    public HotdeployClassLoader(final ClassLoader classLoader) {
        super(classLoader);
    }

    /**
     * getFindLoadedClassMethod
     * @return Method
     */
    private static Method getFindLoadedClassMethod() {
        try {
            final Method method = ClassLoader.class.getDeclaredMethod(
                    "findLoadedClass", String.class);
            method.setAccessible(true);
            return method;
        } catch (final NoSuchMethodException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * @see java.lang.ClassLoader#loadClass(java.lang.String, boolean)
     */
    @Override
    public Class<?> loadClass(final String className, final boolean resolve)
            throws ClassNotFoundException {
        if (isTargetClass(className)) {
            Class<?> clazz = findLoadedClass(className);
            if (clazz != null) {
                return clazz;
            }
            clazz = findLoadedClass(getParent(), className);
            if (clazz != null) {
                return clazz;
            }
            clazz = defineClass(className, resolve);
            if (clazz != null) {
                return clazz;
            }
        }
        return super.loadClass(className, resolve);
    }

    /**
     * 指定のクラスローダまたはその祖先の暮らすローダが、 このバイナリ名を持つクラスの起動ローダとしてJava仮想マシンにより記録されていた場合は、
     * 指定されたバイナリ名を持つクラスを返します。 記録されていなかった場合は<code>null</code>を返します。
     *
     * @param classLoader クラスローダ
     * @param className クラスのバイナリ名
     * @return <code>Class</code>オブジェクト。クラスがロードされていない場合は<code>null</code>
     */
    private static Class<?> findLoadedClass(final ClassLoader classLoader, final String className) {
        for (ClassLoader loader = classLoader; loader != null; loader = loader.getParent()) {
            final Object clazz = invoke(loader, className);
            if (clazz != null) {
                return (Class<?>) clazz;
            }
        }
        return null;
    }

    /**
     * {@link Method#invoke(Object, Object[])}の例外処理をラップします。
     *
     * @param target Object
     * @param args Object
     * @return 戻り値
     */
    private static Object invoke(final Object target, final Object... args) {
        try {
            return FIND_LOADED_CLASS_METHOD.invoke(target, args);
        } catch (final ReflectiveOperationException ex) {
            final Throwable t = ex.getCause();
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            } else if (t instanceof Error) {
                throw (Error) t;
            }
            throw new RuntimeException(ex);
        }
    }

    /**
     * {@link Class}を定義します。
     *
     * @param className クラス名
     * @param resolve resolveフラグ
     * @return {@link Class}
     */
    private Class<?> defineClass(final String className, final boolean resolve) {
        final byte[] bytes = getClassBytes(className);
        if (0 < bytes.length) {
            final Class<?> clazz = defineClass(className, bytes);
            if (resolve) {
                resolveClass(clazz);
            }
            return clazz;
        }
        return null;
    }

    /**
     * クラスバイト配列取得
     * @param className クラス名
     * @return バイト配列
     */
    private byte[] getClassBytes(final String className) {
        byte[] bytes = new byte[0];
        final String path = className.replace(".", "/") + ".class";
        final URL url = Thread.currentThread().getContextClassLoader().getResource(path);
        if (url != null) {
            try (InputStream is = url.openStream()) {
                bytes = toByteArray(is);
            } catch (final IOException e) {
                LogManager.getLogger().warn(e.getMessage(), e);
            }
        }
        return bytes;
    }

    /**
     * {@link InputStream}からbyteの配列を取得します。
     *
     * @param is InputStream
     * @return byteの配列
     * @throws IOException IOException
     */
    private static byte[] toByteArray(final InputStream is) throws IOException {
        final byte[] buf = new byte[8192];
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        int n;
        while ((n = is.read(buf, 0, buf.length)) != -1) {
            baos.write(buf, 0, n);
        }
        return baos.toByteArray();
    }

    /**
     * {@link Class}を定義します。
     *
     * @param className クラス名
     * @param bytes buffer
     * @return {@link Class}
     */
    protected Class<?> defineClass(final String className, final byte[] bytes) {
        return defineClass(className, bytes, 0, bytes.length);
    }

    /**
     * HOT deployの対象のクラスかどうか返します。
     *
     * @param className クラス名
     * @return HOT deployの対象のクラスかどうか
     */
    protected abstract boolean isTargetClass(String className);
}
