diff options
Diffstat (limited to 'src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java')
-rw-r--r-- | src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java | 430 |
1 files changed, 430 insertions, 0 deletions
diff --git a/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java b/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java new file mode 100644 index 00000000..fd73ccfd --- /dev/null +++ b/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java @@ -0,0 +1,430 @@ +// Copyright (C) 2009 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.gitblit.transport.ssh.commands; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.io.OutputStream; +import java.io.StringWriter; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.sshd.common.SshException; +import org.apache.sshd.server.Command; +import org.apache.sshd.server.Environment; +import org.apache.sshd.server.ExitCallback; +import org.kohsuke.args4j.Argument; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.gitblit.transport.ssh.AbstractSshCommand; +import com.gitblit.utils.IdGenerator; +import com.gitblit.utils.WorkQueue; +import com.gitblit.utils.cli.CmdLineParser; +import com.google.common.base.Charsets; +import com.google.common.util.concurrent.Atomics; + +public abstract class BaseCommand extends AbstractSshCommand { + private static final Logger log = LoggerFactory + .getLogger(BaseCommand.class); + + /** Text of the command line which lead up to invoking this instance. */ + private String commandName = ""; + + /** Unparsed command line options. */ + private String[] argv; + + /** The task, as scheduled on a worker thread. */ + private final AtomicReference<Future<?>> task; + + private final WorkQueue.Executor executor; + + public BaseCommand() { + task = Atomics.newReference(); + IdGenerator gen = new IdGenerator(); + WorkQueue w = new WorkQueue(gen); + this.executor = w.getDefaultQueue(); + } + + public void setInputStream(final InputStream in) { + this.in = in; + } + + public void setOutputStream(final OutputStream out) { + this.out = out; + } + + public void setErrorStream(final OutputStream err) { + this.err = err; + } + + public void setExitCallback(final ExitCallback callback) { + this.exit = callback; + } + + protected void provideStateTo(final Command cmd) { + cmd.setInputStream(in); + cmd.setOutputStream(out); + cmd.setErrorStream(err); + cmd.setExitCallback(exit); + } + + protected String getName() { + return commandName; + } + + void setName(final String prefix) { + this.commandName = prefix; + } + + public String[] getArguments() { + return argv; + } + + public void setArguments(final String[] argv) { + this.argv = argv; + } + + /** + * Parses the command line argument, injecting parsed values into fields. + * <p> + * This method must be explicitly invoked to cause a parse. + * + * @throws UnloggedFailure if the command line arguments were invalid. + * @see Option + * @see Argument + */ + protected void parseCommandLine() throws UnloggedFailure { + parseCommandLine(this); + } + + /** + * Parses the command line argument, injecting parsed values into fields. + * <p> + * This method must be explicitly invoked to cause a parse. + * + * @param options object whose fields declare Option and Argument annotations + * to describe the parameters of the command. Usually {@code this}. + * @throws UnloggedFailure if the command line arguments were invalid. + * @see Option + * @see Argument + */ + protected void parseCommandLine(Object options) throws UnloggedFailure { + final CmdLineParser clp = newCmdLineParser(options); + try { + clp.parseArgument(argv); + } catch (IllegalArgumentException err) { + if (!clp.wasHelpRequestedByOption()) { + throw new UnloggedFailure(1, "fatal: " + err.getMessage()); + } + } catch (CmdLineException err) { + if (!clp.wasHelpRequestedByOption()) { + throw new UnloggedFailure(1, "fatal: " + err.getMessage()); + } + } + + if (clp.wasHelpRequestedByOption()) { + StringWriter msg = new StringWriter(); + clp.printDetailedUsage(commandName, msg); + msg.write(usage()); + throw new UnloggedFailure(1, msg.toString()); + } + } + + /** Construct a new parser for this command's received command line. */ + protected CmdLineParser newCmdLineParser(Object options) { + return new CmdLineParser(options); + } + + protected String usage() { + return ""; + } + + private final class TaskThunk implements com.gitblit.utils.WorkQueue.CancelableRunnable { + private final CommandRunnable thunk; + private final String taskName; + + private TaskThunk(final CommandRunnable thunk) { + this.thunk = thunk; + + // TODO +// StringBuilder m = new StringBuilder("foo"); +// m.append(context.getCommandLine()); +// if (userProvider.get().isIdentifiedUser()) { +// IdentifiedUser u = (IdentifiedUser) userProvider.get(); +// m.append(" (").append(u.getAccount().getUserName()).append(")"); +// } + this.taskName = "foo";//m.toString(); + } + + @Override + public void cancel() { + synchronized (this) { + //final Context old = sshScope.set(context); + try { + //onExit(/*STATUS_CANCEL*/); + } finally { + //sshScope.set(old); + } + } + } + + @Override + public void run() { + synchronized (this) { + final Thread thisThread = Thread.currentThread(); + final String thisName = thisThread.getName(); + int rc = 0; + //final Context old = sshScope.set(context); + try { + //context.started = TimeUtil.nowMs(); + thisThread.setName("SSH " + taskName); + + thunk.run(); + + out.flush(); + err.flush(); + } catch (Throwable e) { + try { + out.flush(); + } catch (Throwable e2) { + } + try { + err.flush(); + } catch (Throwable e2) { + } + rc = handleError(e); + } finally { + try { + onExit(rc); + } finally { + thisThread.setName(thisName); + } + } + } + } + + @Override + public String toString() { + return taskName; + } + } + + /** Runnable function which can throw an exception. */ + public static interface CommandRunnable { + public void run() throws Exception; + } + + + /** + * Spawn a function into its own thread. + * <p> + * Typically this should be invoked within {@link Command#start(Environment)}, + * such as: + * + * <pre> + * startThread(new Runnable() { + * public void run() { + * runImp(); + * } + * }); + * </pre> + * + * @param thunk the runnable to execute on the thread, performing the + * command's logic. + */ + protected void startThread(final Runnable thunk) { + startThread(new CommandRunnable() { + @Override + public void run() throws Exception { + thunk.run(); + } + }); + } + + /** + * Terminate this command and return a result code to the remote client. + * <p> + * Commands should invoke this at most once. Once invoked, the command may + * lose access to request based resources as any callbacks previously + * registered with {@link RequestCleanup} will fire. + * + * @param rc exit code for the remote client. + */ + protected void onExit(final int rc) { + exit.onExit(rc); +// if (cleanup != null) { +// cleanup.run(); +// } + } + + private int handleError(final Throwable e) { + if ((e.getClass() == IOException.class + && "Pipe closed".equals(e.getMessage())) + || // + (e.getClass() == SshException.class + && "Already closed".equals(e.getMessage())) + || // + e.getClass() == InterruptedIOException.class) { + // This is sshd telling us the client just dropped off while + // we were waiting for a read or a write to complete. Either + // way its not really a fatal error. Don't log it. + // + return 127; + } + + if (e instanceof UnloggedFailure) { + } else { + final StringBuilder m = new StringBuilder(); + m.append("Internal server error"); +// if (userProvider.get().isIdentifiedUser()) { +// final IdentifiedUser u = (IdentifiedUser) userProvider.get(); +// m.append(" (user "); +// m.append(u.getAccount().getUserName()); +// m.append(" account "); +// m.append(u.getAccountId()); +// m.append(")"); +// } +// m.append(" during "); +// m.append(contextProvider.get().getCommandLine()); + log.error(m.toString(), e); + } + + if (e instanceof Failure) { + final Failure f = (Failure) e; + try { + err.write((f.getMessage() + "\n").getBytes(Charsets.UTF_8)); + err.flush(); + } catch (IOException e2) { + } catch (Throwable e2) { + log.warn("Cannot send failure message to client", e2); + } + return f.exitCode; + + } else { + try { + err.write("fatal: internal server error\n".getBytes(Charsets.UTF_8)); + err.flush(); + } catch (IOException e2) { + } catch (Throwable e2) { + log.warn("Cannot send internal server error message to client", e2); + } + return 128; + } + } + + /** + * Spawn a function into its own thread. + * <p> + * Typically this should be invoked within {@link Command#start(Environment)}, + * such as: + * + * <pre> + * startThread(new CommandRunnable() { + * public void run() throws Exception { + * runImp(); + * } + * }); + * </pre> + * <p> + * If the function throws an exception, it is translated to a simple message + * for the client, a non-zero exit code, and the stack trace is logged. + * + * @param thunk the runnable to execute on the thread, performing the + * command's logic. + */ + protected void startThread(final CommandRunnable thunk) { + final TaskThunk tt = new TaskThunk(thunk); + task.set(executor.submit(tt)); + } + + /** Thrown from {@link CommandRunnable#run()} with client message and code. */ + public static class Failure extends Exception { + private static final long serialVersionUID = 1L; + + final int exitCode; + + /** + * Create a new failure. + * + * @param exitCode exit code to return the client, which indicates the + * failure status of this command. Should be between 1 and 255, + * inclusive. + * @param msg message to also send to the client's stderr. + */ + public Failure(final int exitCode, final String msg) { + this(exitCode, msg, null); + } + + /** + * Create a new failure. + * + * @param exitCode exit code to return the client, which indicates the + * failure status of this command. Should be between 1 and 255, + * inclusive. + * @param msg message to also send to the client's stderr. + * @param why stack trace to include in the server's log, but is not sent to + * the client's stderr. + */ + public Failure(final int exitCode, final String msg, final Throwable why) { + super(msg, why); + this.exitCode = exitCode; + } + } + + /** Thrown from {@link CommandRunnable#run()} with client message and code. */ + public static class UnloggedFailure extends Failure { + private static final long serialVersionUID = 1L; + + /** + * Create a new failure. + * + * @param msg message to also send to the client's stderr. + */ + public UnloggedFailure(final String msg) { + this(1, msg); + } + + /** + * Create a new failure. + * + * @param exitCode exit code to return the client, which indicates the + * failure status of this command. Should be between 1 and 255, + * inclusive. + * @param msg message to also send to the client's stderr. + */ + public UnloggedFailure(final int exitCode, final String msg) { + this(exitCode, msg, null); + } + + /** + * Create a new failure. + * + * @param exitCode exit code to return the client, which indicates the + * failure status of this command. Should be between 1 and 255, + * inclusive. + * @param msg message to also send to the client's stderr. + * @param why stack trace to include in the server's log, but is not sent to + * the client's stderr. + */ + public UnloggedFailure(final int exitCode, final String msg, + final Throwable why) { + super(exitCode, msg, why); + } + } +} |