]> source.dussan.org Git - jgit.git/blob
4c698d974a4113531584ac760dc06ccb33f2570b
[jgit.git] /
1 /*
2  * Copyright (C) 2021, Thomas Wolf <thomas.wolf@paranor.ch> and others
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Distribution License v. 1.0 which is available at
6  * https://www.eclipse.org/org/documents/edl-v10.php.
7  *
8  * SPDX-License-Identifier: BSD-3-Clause
9  */
10 package org.eclipse.jgit.internal.transport.sshd.agent.connector;
11
12 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.Sockets.AF_UNIX;
13 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.Sockets.DEFAULT_PROTOCOL;
14 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.Sockets.ENV_SSH_AUTH_SOCK;
15 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.Sockets.SOCK_STREAM;
16 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.UnixSockets.FD_CLOEXEC;
17 import static org.eclipse.jgit.internal.transport.sshd.agent.connector.UnixSockets.F_SETFD;
18
19 import java.io.IOException;
20 import java.nio.charset.StandardCharsets;
21 import java.text.MessageFormat;
22 import java.util.Arrays;
23 import java.util.concurrent.atomic.AtomicBoolean;
24
25 import org.apache.sshd.common.SshException;
26 import org.eclipse.jgit.transport.sshd.agent.AbstractConnector;
27 import org.eclipse.jgit.util.StringUtils;
28 import org.eclipse.jgit.util.SystemReader;
29 import org.slf4j.Logger;
30 import org.slf4j.LoggerFactory;
31
32 import com.sun.jna.LastErrorException;
33 import com.sun.jna.Native;
34 import com.sun.jna.platform.unix.LibCAPI;
35
36 /**
37  * JNA-based implementation of communication through a Unix domain socket.
38  */
39 public class UnixDomainSocketConnector extends AbstractConnector {
40
41         private static final Logger LOG = LoggerFactory
42                         .getLogger(UnixDomainSocketConnector.class);
43
44         private static UnixSockets library;
45
46         private static boolean libraryLoaded = false;
47
48         private static synchronized UnixSockets getLibrary() {
49                 if (!libraryLoaded) {
50                         libraryLoaded = true;
51                         try {
52                                 library = Native.load(UnixSockets.LIBRARY_NAME, UnixSockets.class);
53                         } catch (Exception | UnsatisfiedLinkError
54                                         | NoClassDefFoundError e) {
55                                 LOG.error(Texts.get().logErrorLoadLibrary, e);
56                         }
57                 }
58                 return library;
59         }
60
61         private final String socketFile;
62
63         private AtomicBoolean connected = new AtomicBoolean();
64
65         private volatile int socketFd = -1;
66
67         /**
68          * Creates a new instance.
69          *
70          * @param socketFile
71          *            to use; if {@code null} or empty, use environment variable
72          *            SSH_AUTH_SOCK
73          */
74         public UnixDomainSocketConnector(String socketFile) {
75                 super();
76                 String file = socketFile;
77                 if (StringUtils.isEmptyOrNull(file)) {
78                         file = SystemReader.getInstance().getenv(ENV_SSH_AUTH_SOCK);
79                 }
80                 this.socketFile = file;
81         }
82
83         @Override
84         public boolean connect() throws IOException {
85                 if (StringUtils.isEmptyOrNull(socketFile)) {
86                         return false;
87                 }
88                 int fd = socketFd;
89                 synchronized (this) {
90                         if (connected.get()) {
91                                 return true;
92                         }
93                         UnixSockets sockets = getLibrary();
94                         if (sockets == null) {
95                                 return false;
96                         }
97                         try {
98                                 fd = sockets.socket(AF_UNIX, SOCK_STREAM, DEFAULT_PROTOCOL);
99                                 // OS X apparently doesn't have SOCK_CLOEXEC, so we can't set it
100                                 // atomically. Set it via fcntl, which exists on all systems
101                                 // we're interested in.
102                                 sockets.fcntl(fd, F_SETFD, FD_CLOEXEC);
103                                 Sockets.SockAddr sockAddr = new Sockets.SockAddr(socketFile,
104                                                 StandardCharsets.UTF_8);
105                                 sockets.connect(fd, sockAddr, sockAddr.size());
106                                 connected.set(true);
107                         } catch (LastErrorException e) {
108                                 if (fd >= 0) {
109                                         try {
110                                                 sockets.close(fd);
111                                         } catch (LastErrorException e1) {
112                                                 e.addSuppressed(e1);
113                                         }
114                                 }
115                                 throw new IOException(MessageFormat
116                                                 .format(Texts.get().msgConnectFailed, socketFile), e);
117                         }
118                 }
119                 socketFd = fd;
120                 return connected.get();
121         }
122
123         @Override
124         public synchronized void close() throws IOException {
125                 int fd = socketFd;
126                 if (connected.getAndSet(false) && fd >= 0) {
127                         socketFd = -1;
128                         try {
129                                 getLibrary().close(fd);
130                         } catch (LastErrorException e) {
131                                 throw new IOException(MessageFormat.format(
132                                                 Texts.get().msgCloseFailed, Integer.toString(fd)), e);
133                         }
134                 }
135         }
136
137         @Override
138         public byte[] rpc(byte command, byte[] message) throws IOException {
139                 prepareMessage(command, message);
140                 int fd = socketFd;
141                 if (!connected.get() || fd < 0) {
142                         // No translation, internal error
143                         throw new IllegalStateException("Not connected to SSH agent"); //$NON-NLS-1$
144                 }
145                 writeFully(fd, message);
146                 // Now receive the reply
147                 byte[] lengthBuf = new byte[4];
148                 readFully(fd, lengthBuf);
149                 int length = toLength(command, lengthBuf);
150                 byte[] payload = new byte[length];
151                 readFully(fd, payload);
152                 return payload;
153         }
154
155         private void writeFully(int fd, byte[] message) throws IOException {
156                 int toWrite = message.length;
157                 try {
158                         byte[] buf = message;
159                         while (toWrite > 0) {
160                                 int written = getLibrary()
161                                                 .write(fd, buf, new LibCAPI.size_t(buf.length))
162                                                 .intValue();
163                                 if (written < 0) {
164                                         throw new IOException(
165                                                         MessageFormat.format(Texts.get().msgSendFailed,
166                                                                         Integer.toString(message.length),
167                                                                         Integer.toString(toWrite)));
168                                 }
169                                 toWrite -= written;
170                                 if (written > 0 && toWrite > 0) {
171                                         buf = Arrays.copyOfRange(buf, written, buf.length);
172                                 }
173                         }
174                 } catch (LastErrorException e) {
175                         throw new IOException(
176                                         MessageFormat.format(Texts.get().msgSendFailed,
177                                                         Integer.toString(message.length),
178                                                         Integer.toString(toWrite)),
179                                         e);
180                 }
181         }
182
183         private void readFully(int fd, byte[] data) throws IOException {
184                 int n = 0;
185                 int offset = 0;
186                 while (offset < data.length
187                                 && (n = read(fd, data, offset, data.length - offset)) > 0) {
188                         offset += n;
189                 }
190                 if (offset < data.length) {
191                         throw new SshException(
192                                         MessageFormat.format(Texts.get().msgShortRead,
193                                                         Integer.toString(data.length),
194                                                         Integer.toString(offset), Integer.toString(n)));
195                 }
196         }
197
198         private int read(int fd, byte[] buffer, int offset, int length)
199                         throws IOException {
200                 try {
201                         LibCAPI.size_t toRead = new LibCAPI.size_t(length);
202                         if (offset == 0) {
203                                 return getLibrary().read(fd, buffer, toRead).intValue();
204                         }
205                         byte[] data = new byte[length];
206                         int read = getLibrary().read(fd, data, toRead).intValue();
207                         if (read > 0) {
208                                 System.arraycopy(data, 0, buffer, offset, read);
209                         }
210                         return read;
211                 } catch (LastErrorException e) {
212                         throw new IOException(
213                                         MessageFormat.format(Texts.get().msgReadFailed,
214                                                         Integer.toString(length)),
215                                         e);
216                 }
217         }
218 }