From de4a8d4f35ccb9918326fb1b893b62c68e1967c5 Mon Sep 17 00:00:00 2001
From: Alban Crequy <alban.crequy@collabora.co.uk>
Date: Wed, 16 Jul 2014 16:36:52 +0100
Subject: [PATCH 3/3] enforce new limit max_connections_per_cgroup

When it is not possible to get the Unix pid of a connection, that connection
bypass the limit. It can happen on Windows or on the non-supported case where
dbus-daemon is configured to listen on TCP rather than a Unix socket.
---
 bus/connection.c | 175 ++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 166 insertions(+), 9 deletions(-)

diff --git a/bus/connection.c b/bus/connection.c
index ea2d155..e064a0c 100644
--- a/bus/connection.c
+++ b/bus/connection.c
@@ -59,6 +59,7 @@ struct BusConnections
   int n_incomplete;     /**< Length of incomplete list */
   BusContext *context;
   DBusHashTable *completed_by_user; /**< Number of completed connections for each UID */
+  DBusHashTable *completed_by_cgroup; /**< Number of completed connections for each cgroup */
   DBusTimeout *expire_timeout; /**< Timeout for expiring incomplete connections. */
   int stamp;                   /**< Incrementing number */
   BusExpireList *pending_replies; /**< List of pending replies */
@@ -90,6 +91,8 @@ typedef struct
   DBusMessage *oom_message;
   DBusPreallocatedSend *oom_preallocated;
   BusClientPolicy *policy;
+  dbus_bool_t tracking_cgroup;
+  DBusString cgroup;
 
   char *cached_loginfo_string;
   BusSELinuxID *selinux_id;
@@ -127,6 +130,23 @@ connection_get_loop (DBusConnection *connection)
 
 
 static int
+get_connections_for_cgroup (BusConnections *connections,
+                            DBusString     *cgroup)
+{
+  void *val;
+  int current_count;
+
+  /* val is NULL is 0 when it isn't in the hash yet */
+
+  val = _dbus_hash_table_lookup_string (connections->completed_by_cgroup,
+                                        _dbus_string_get_data (cgroup));
+
+  current_count = _DBUS_POINTER_TO_INT (val);
+
+  return current_count;
+}
+
+static int
 get_connections_for_uid (BusConnections *connections,
                          dbus_uid_t      uid)
 {
@@ -144,6 +164,72 @@ get_connections_for_uid (BusConnections *connections,
 }
 
 static dbus_bool_t
+adjust_connections_for_cgroup (BusConnections *connections,
+                               DBusString     *cgroup,
+                               int             adjustment)
+{
+  int current_count;
+  char *cgroup_str;
+
+  cgroup_str = _dbus_string_get_data (cgroup);
+
+  current_count = get_connections_for_cgroup (connections, cgroup);
+
+  _dbus_verbose ("Adjusting connection count for cgroup %s: "
+                 "was %d adjustment %d making %d\n",
+                 cgroup_str, current_count, adjustment,
+                 current_count + adjustment);
+
+  _dbus_assert (current_count >= 0);
+
+  current_count += adjustment;
+
+  _dbus_assert (current_count >= 0);
+
+  if (current_count == 0)
+    {
+      _dbus_hash_table_remove_string (connections->completed_by_cgroup,
+                                      cgroup_str);
+      return TRUE;
+    }
+  else
+    {
+      dbus_bool_t retval;
+      DBusString copy;
+
+      /* duplicate the key */
+
+      if (!_dbus_string_init (&copy))
+        return FALSE;
+
+      if (!_dbus_string_copy (cgroup, 0, &copy, 0))
+        {
+          _dbus_string_free (&copy);
+          return FALSE;
+        }
+
+      if (!_dbus_string_steal_data (&copy, &cgroup_str))
+        {
+          _dbus_string_free (&copy);
+          return FALSE;
+        }
+
+      _dbus_string_free (&copy);
+
+      retval = _dbus_hash_table_insert_string (connections->completed_by_cgroup,
+                   cgroup_str, _DBUS_INT_TO_POINTER (current_count));
+
+      /* only positive adjustment can fail as otherwise
+       * a hash entry should already exist
+       */
+      _dbus_assert (adjustment > 0 ||
+                    (adjustment <= 0 && retval));
+
+      return retval;
+    }
+}
+
+static dbus_bool_t
 adjust_connections_for_uid (BusConnections *connections,
                             dbus_uid_t      uid,
                             int             adjustment)
@@ -286,6 +372,12 @@ bus_connection_disconnected (DBusConnection *connection)
               if (!adjust_connections_for_uid (d->connections,
                                                uid, -1))
                 _dbus_assert_not_reached ("adjusting downward should never fail");
+
+            }
+          if (d->tracking_cgroup)
+            {
+              if (!adjust_connections_for_cgroup (d->connections, &d->cgroup, -1))
+                _dbus_assert_not_reached ("adjusting downward should never fail");
             }
         }
       else
@@ -299,6 +391,8 @@ bus_connection_disconnected (DBusConnection *connection)
       _dbus_assert (d->connections->n_completed >= 0);
     }
 
+  _dbus_string_free (&d->cgroup);
+
   bus_connection_drop_pending_replies (d->connections, connection);
   
   /* frees "d" as side effect */
@@ -429,11 +523,16 @@ bus_connections_new (BusContext *context)
   if (connections->completed_by_user == NULL)
     goto failed_2;
 
+  connections->completed_by_cgroup = _dbus_hash_table_new (DBUS_HASH_STRING,
+                                                         dbus_free, NULL);
+  if (connections->completed_by_cgroup == NULL)
+    goto failed_3;
+
   connections->expire_timeout = _dbus_timeout_new (100, /* irrelevant */
                                                    expire_incomplete_timeout,
                                                    connections, NULL);
   if (connections->expire_timeout == NULL)
-    goto failed_3;
+    goto failed_4;
 
   _dbus_timeout_set_enabled (connections->expire_timeout, FALSE);
 
@@ -442,21 +541,23 @@ bus_connections_new (BusContext *context)
                                                       bus_pending_reply_expired,
                                                       connections);
   if (connections->pending_replies == NULL)
-    goto failed_4;
+    goto failed_5;
   
   if (!_dbus_loop_add_timeout (bus_context_get_loop (context),
                                connections->expire_timeout))
-    goto failed_5;
+    goto failed_6;
   
   connections->refcount = 1;
   connections->context = context;
   
   return connections;
 
- failed_5:
+ failed_6:
   bus_expire_list_free (connections->pending_replies);
- failed_4:
+ failed_5:
   _dbus_timeout_unref (connections->expire_timeout);
+ failed_4:
+  _dbus_hash_table_unref (connections->completed_by_cgroup);
  failed_3:
   _dbus_hash_table_unref (connections->completed_by_user);
  failed_2:
@@ -521,6 +622,7 @@ bus_connections_unref (BusConnections *connections)
       _dbus_timeout_unref (connections->expire_timeout);
       
       _dbus_hash_table_unref (connections->completed_by_user);
+      _dbus_hash_table_unref (connections->completed_by_cgroup);
       
       dbus_free (connections);
 
@@ -596,7 +698,6 @@ bus_connections_setup_connection (BusConnections *connections,
   BusConnectionData *d;
   dbus_bool_t retval;
   DBusError error;
-
   
   d = dbus_new0 (BusConnectionData, 1);
   
@@ -661,6 +762,9 @@ bus_connections_setup_connection (BusConnections *connections,
                                           allow_unix_user_function,
                                           NULL, NULL);
 
+  if (!_dbus_string_init (&d->cgroup))
+    goto out;
+
   dbus_connection_set_dispatch_status_function (connection,
                                                 dispatch_status_function,
                                                 bus_context_get_loop (connections->context),
@@ -738,6 +842,8 @@ bus_connections_setup_connection (BusConnections *connections,
       dbus_connection_set_unix_user_function (connection,
                                               NULL, NULL, NULL);
 
+      _dbus_string_free (&d->cgroup);
+
       dbus_connection_set_windows_user_function (connection,
                                                  NULL, NULL, NULL);
       
@@ -1358,6 +1464,7 @@ bus_connection_complete (DBusConnection   *connection,
 {
   BusConnectionData *d;
   unsigned long uid;
+  unsigned long pid;
   
   d = BUS_CONNECTION_DATA (connection);
   _dbus_assert (d != NULL);
@@ -1397,8 +1504,17 @@ bus_connection_complete (DBusConnection   *connection,
   
   if (dbus_connection_get_unix_user (connection, &uid))
     {
-      if (!adjust_connections_for_uid (d->connections,
-                                       uid, 1))
+      if (!adjust_connections_for_uid (d->connections, uid, 1))
+        goto fail;
+
+    }
+
+  if (d->tracking_cgroup)
+    {
+      if (!dbus_connection_get_unix_process_id (connection, &pid))
+        goto fail;
+
+      if (!adjust_connections_for_cgroup (d->connections, &d->cgroup, 1))
         goto fail;
     }
 
@@ -1492,6 +1608,13 @@ bus_connections_check_limits (BusConnections  *connections,
                               DBusError       *error)
 {
   unsigned long uid;
+  unsigned long pid;
+  int count;
+  BusConnectionData *d;
+
+  d = BUS_CONNECTION_DATA (requesting_completion);
+  _dbus_assert (d != NULL);
+
 
   if (connections->n_completed >=
       bus_context_get_max_completed_connections (connections->context))
@@ -1512,7 +1635,41 @@ bus_connections_check_limits (BusConnections  *connections,
           return FALSE;
         }
     }
-  
+
+  /* bus_connections_check_limits() could be called several times on the same
+   * connection. Make sure to be idempotent. */
+  if (!d->tracking_cgroup &&
+      dbus_connection_get_unix_process_id (requesting_completion, &pid))
+    {
+      d->tracking_cgroup = TRUE;
+
+      if (!_dbus_cgroup_for_pid (pid, &d->cgroup, 4096, error))
+        {
+          _dbus_verbose ("Couldn't get connection's cgroup\n");
+          dbus_error_free (error);
+          /* ignore the error: the connection will be accounted for cgroup ""
+           */
+        }
+    }
+
+  _dbus_verbose ("Checking max_connections_per_cgroup? %d\n", d->tracking_cgroup);
+  if (d->tracking_cgroup)
+    {
+      _dbus_verbose ("Checking max_connections_per_cgroup: current count=%d max=%d\n",
+                     get_connections_for_cgroup (connections, &d->cgroup),
+                     bus_context_get_max_connections_per_cgroup (connections->context));
+
+      count = get_connections_for_cgroup (connections, &d->cgroup);
+      if (count >=
+          bus_context_get_max_connections_per_cgroup (connections->context))
+        {
+          dbus_set_error (error, DBUS_ERROR_LIMITS_EXCEEDED,
+                          "The maximum number of active connections for cgroup %s has been reached",
+                          _dbus_string_get_data (&d->cgroup));
+          return FALSE;
+        }
+    }
+
   return TRUE;
 }
 
-- 
1.8.5.3