Parcourir la source

lib: support the new pidfd based APIs

Try to make API requests using the new pidfd based APIs. If getting
the pidfds fails or if the remote (daemon) does not support the new
pidfd based D-Bus API, transparently fall back to the old API.
Christian Kellner il y a 5 ans
Parent
commit
6f7df91b60
2 fichiers modifiés avec 124 ajouts et 26 suppressions
  1. 123 26
      lib/client_impl.c
  2. 1 0
      lib/meson.build

+ 123 - 26
lib/client_impl.c

@@ -31,8 +31,14 @@ POSSIBILITY OF SUCH DAMAGE.
 
 #define _GNU_SOURCE
 
+#include <common-helpers.h>
+#include <common-pidfds.h>
+
 #include <dbus/dbus.h>
+#include <errno.h>
 #include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <unistd.h>
@@ -54,6 +60,7 @@ POSSIBILITY OF SUCH DAMAGE.
 #define _cleanup_bus_ _cleanup_(hop_off_the_bus)
 #define _cleanup_msg_ _cleanup_(cleanup_msg)
 #define _cleanup_dpc_ _cleanup_(cleanup_pending_call)
+#define _cleanup_fds_ _cleanup_(cleanup_fd_array)
 
 #ifdef NDEBUG
 #define DEBUG(...)
@@ -73,6 +80,35 @@ static int log_error(const char *fmt, ...) __attribute__((format(printf, 1, 2)))
 // Storage for error strings
 static char error_string[512] = { 0 };
 
+// memory helpers
+static void cleanup_fd_array(int **fdlist)
+{
+	if (fdlist == NULL || *fdlist == NULL)
+		return;
+
+	int errsave = errno;
+	for (int *fd = *fdlist; *fd != -1; fd++) {
+		TRACE("GM Closing fd %d\n", *fd);
+		(void)close(*fd);
+	}
+
+	errno = errsave;
+	free(*fdlist);
+}
+
+// Allocate a -1 termianted array of ints
+static inline int *alloc_fd_array(int n)
+{
+	int *fds;
+
+	size_t count = (size_t)n + 1; /* -1, terminated */
+	fds = (int *)malloc(sizeof(int) * count);
+	for (size_t i = 0; i < count; i++)
+		fds[i] = -1;
+
+	return fds;
+}
+
 // Helper to check if we are running inside a flatpak
 static int in_flatpak(void)
 {
@@ -136,7 +172,7 @@ static DBusConnection *hop_on_the_bus(void)
 /* cleanup functions */
 static void cleanup_msg(DBusMessage **msg)
 {
-	if (msg == NULL)
+	if (msg == NULL || *msg == NULL)
 		return;
 
 	dbus_message_unref(*msg);
@@ -144,24 +180,54 @@ static void cleanup_msg(DBusMessage **msg)
 
 static void cleanup_pending_call(DBusPendingCall **call)
 {
-	if (call == NULL)
+	if (call == NULL || *call == NULL)
 		return;
 
 	dbus_pending_call_unref(*call);
 }
 
 /* internal API */
-static int make_request(DBusConnection *bus,
-			int native, const char *method,
-			pid_t *pids, int npids,
-			DBusError *error)
+static int make_request(DBusConnection *bus, int native, int use_pidfds, const char *method,
+                        pid_t *pids, int npids, DBusError *error)
 {
 	_cleanup_msg_ DBusMessage *msg = NULL;
 	_cleanup_dpc_ DBusPendingCall *call = NULL;
+	_cleanup_fds_ int *fds = NULL;
+	char action[256] = {
+		0,
+	};
 	DBusError err;
 	DBusMessageIter iter;
 	int res = -1;
 
+	TRACE("GM: Incoming request: %s, npids: %d, native: %d pifds: %d\n",
+	      method,
+	      npids,
+	      native,
+	      use_pidfds);
+
+	if (use_pidfds) {
+		fds = alloc_fd_array(npids);
+
+		res = open_pidfds(pids, fds, npids);
+		if (res != npids) {
+			dbus_set_error(error, DBUS_ERROR_FAILED, "Could not open pidfd for %d", (int)pids[res]);
+			return -1;
+		}
+
+		if (strstr(method, "ByPID"))
+			snprintf(action, sizeof(action), "%sFd", method);
+		else
+			snprintf(action, sizeof(action), "%sByPIDFd", method);
+		method = action;
+	}
+
+	TRACE("GM:   Making request: %s, npids: %d, native: %d pifds: %d\n",
+	      method,
+	      npids,
+	      native,
+	      use_pidfds);
+
 	// If we are inside a flatpak we need to talk to the portal instead
 	const char *dest = native ? DAEMON_DBUS_NAME : PORTAL_DBUS_NAME;
 	const char *path = native ? DAEMON_DBUS_PATH : PORTAL_DBUS_PATH;
@@ -170,15 +236,24 @@ static int make_request(DBusConnection *bus,
 	msg = dbus_message_new_method_call(dest, path, iface, method);
 
 	if (!msg) {
-		dbus_set_error_const(error, DBUS_ERROR_FAILED,
-				     "Could not create dbus message");
+		dbus_set_error_const(error, DBUS_ERROR_FAILED, "Could not create dbus message");
 		return -1;
 	}
 
 	dbus_message_iter_init_append(msg, &iter);
+
 	for (int i = 0; i < npids; i++) {
-		dbus_int32_t p = (dbus_int32_t)pids[i];
-		dbus_message_iter_append_basic(&iter, DBUS_TYPE_INT32, &p);
+		dbus_int32_t p;
+		int type;
+
+		if (use_pidfds) {
+			type = DBUS_TYPE_UNIX_FD;
+			p = (dbus_int32_t)fds[i];
+		} else {
+			type = DBUS_TYPE_INT32;
+			p = (dbus_int32_t)pids[i];
+		}
+		dbus_message_iter_append_basic(&iter, type, &p);
 	}
 
 	dbus_connection_send_with_reply(bus, msg, &call, -1);
@@ -190,21 +265,22 @@ static int make_request(DBusConnection *bus,
 	msg = dbus_pending_call_steal_reply(call);
 
 	if (msg == NULL) {
-		dbus_set_error_const(error, DBUS_ERROR_FAILED,
-				     "Did not receive a reply");
+		dbus_set_error_const(error, DBUS_ERROR_FAILED, "Did not receive a reply");
 		return -1;
 	}
 
 	dbus_error_init(&err);
-
+	res = -1;
 	if (dbus_set_error_from_message(&err, msg)) {
-		dbus_set_error(error, err.name,
-			       "Could not call method '%s' on '%s': %s",
-			       method, dest, err.message);
+		dbus_set_error(error,
+		               err.name,
+		               "Could not call method '%s' on '%s': %s",
+		               method,
+		               dest,
+		               err.message);
 	} else if (!dbus_message_iter_init(msg, &iter) ||
-		   dbus_message_iter_get_arg_type(&iter) != DBUS_TYPE_INT32) {
-		dbus_set_error(error, DBUS_ERROR_INVALID_SIGNATURE,
-			       "Failed to parse response");
+	           dbus_message_iter_get_arg_type(&iter) != DBUS_TYPE_INT32) {
+		dbus_set_error(error, DBUS_ERROR_INVALID_SIGNATURE, "Failed to parse response");
 	} else {
 		dbus_message_iter_get_basic(&iter, &res);
 	}
@@ -219,16 +295,26 @@ static int make_request(DBusConnection *bus,
 static int gamemode_request(const char *method, pid_t for_pid)
 {
 	_cleanup_bus_ DBusConnection *bus = NULL;
+	static int use_pidfs = 1;
 	DBusError err;
-	pid_t pids[2] = {0, for_pid};
+	pid_t pids[2];
 	int npids;
 	int native;
 	int res = -1;
 
 	native = !in_flatpak();
-	pids[0] = getpid();
 
-	TRACE("GM: [%d] request '%s' received (for pid: %d) [portal: %s]\n",
+	/* pid[0] is the client, i.e. the game
+	 * pid[1] is the requestor, i.e. this process
+	 *
+	 * we setup the array such that pids[1] will always be a valid
+	 * pid, because if we are going to use the pidfd based API,
+	 * both pids are being sent, even if they are the same
+	 */
+	pids[1] = getpid();
+	pids[0] = for_pid != 0 ? for_pid : pids[1];
+
+	TRACE("GM: [%d] request '%s' received (by: %d) [portal: %s]\n",
 	      (int)pids[0],
 	      method,
 	      (int)pids[1],
@@ -239,12 +325,23 @@ static int gamemode_request(const char *method, pid_t for_pid)
 	if (bus == NULL)
 		return -1;
 
-	npids = for_pid > 0 ? 2 : 1;
-
 	dbus_error_init(&err);
-	res = make_request(bus, native, method, pids, npids, &err);
+retry:
+	if (for_pid != 0 || use_pidfs)
+		npids = 2;
+	else
+		npids = 1;
+
+	res = make_request(bus, native, use_pidfs, method, pids, npids, &err);
+
+	if (res == -1 && use_pidfs && dbus_error_is_set(&err)) {
+		TRACE("GM: Request with pidfds failed (%s). Retrying.\n", err.message);
+		use_pidfs = 0;
+		dbus_error_free(&err);
+		goto retry;
+	}
 
-	if (res == -1)
+	if (res == -1 && dbus_error_is_set(&err))
 		log_error("D-Bus error: %s", err.message);
 
 	TRACE("GM: [%d] request '%s' done: %d\n", (int)pids[0], method, res);

+ 1 - 0
lib/meson.build

@@ -12,6 +12,7 @@ gamemode = shared_library(
         'client_impl.c',
     ],
     dependencies: [
+        link_lib_common,
         dep_dbus,
     ],
     install: true,