]> diplodocus.org Git - nmh/blobdiff - sbr/oauth.c
Fix tests with oauth disabled.
[nmh] / sbr / oauth.c
index b795eb815e7df675ef71a7e9f2801c1e69a21bbe..57d4bbbabad9b10abed010078fcf7c03c00f1523 100644 (file)
@@ -108,6 +108,9 @@ struct mh_oauth_cred {
     /* Ignoring token_type ([1] 7.1) because
      * https://developers.google.com/accounts/docs/OAuth2InstalledApp says
      * "Currently, this field always has the value Bearer". */
+
+    /* only filled while loading cred files, otherwise NULL */
+    char *user;
 };
 
 struct mh_oauth_ctx {
@@ -192,7 +195,7 @@ mh_oauth_do_xoauth(const char *user, const char *svc, FILE *log)
         adios(fn, "failed to lock");
     }
 
-    if ((cred = mh_oauth_cred_load(fp, ctx)) == NULL) {
+    if ((cred = mh_oauth_cred_load(fp, ctx, user)) == NULL) {
         adios(NULL, mh_oauth_get_err_string(ctx));
     }
 
@@ -209,7 +212,7 @@ mh_oauth_do_xoauth(const char *user, const char *svc, FILE *log)
         }
 
         fseek(fp, 0, SEEK_SET);
-        if (!mh_oauth_cred_save(fp, cred)) {
+        if (!mh_oauth_cred_save(fp, cred, user)) {
             adios(NULL, mh_oauth_get_err_string(ctx));
         }
     }
@@ -262,7 +265,7 @@ set_err_http(mh_oauth_ctx *ctx, const struct curl_ctx *curl_ctx)
     /* 5.2. Error Response says error response should use status code 400 and
      * application/json body.  If Content-Type matches, try to parse the body
      * regardless of the status code. */
-    if (curl_ctx->res_body != NULL
+    if (curl_ctx->res_len > 0
         && is_json(curl_ctx->content_type)
         && get_json_strings(curl_ctx->res_body, curl_ctx->res_len, ctx->log,
                             "error", &error, (void *)NULL)
@@ -487,6 +490,9 @@ mh_oauth_get_err_string(mh_oauth_ctx *ctx)
     case MH_OAUTH_NO_REFRESH:
         base = "no refresh token";
         break;
+    case MH_OAUTH_CRED_USER_NOT_FOUND:
+        base = "user not found in cred file";
+        break;
     case MH_OAUTH_CRED_FILE:
         base = "error loading cred file";
         break;
@@ -716,49 +722,92 @@ mh_oauth_cred_fn(mh_oauth_ctx *ctx)
     return ctx->cred_fn = result;
 }
 
-boolean
-mh_oauth_cred_save(FILE *fp, mh_oauth_cred *cred)
+/* for loading multi-user cred files */
+struct user_creds {
+    mh_oauth_cred *creds;
+
+    /* number of allocated mh_oauth_cred structs above points to */
+    size_t alloc;
+
+    /* number that are actually filled in and used */
+    size_t len;
+};
+
+/* If user has an entry in user_creds, return pointer to it.  Else allocate a
+ * new struct in user_creds and return pointer to that. */
+static mh_oauth_cred *
+find_or_alloc_user_creds(struct user_creds user_creds[], const char *user)
 {
-    int fd = fileno(fp);
-    if (fchmod(fd, S_IRUSR | S_IWUSR) < 0) goto err;
-    if (ftruncate(fd, 0) < 0) goto err;
-    if (cred->access_token != NULL) {
-        if (fprintf(fp, "access: %s\n", cred->access_token) < 0) goto err;
-    }
-    if (cred->refresh_token != NULL) {
-        if (fprintf(fp, "refresh: %s\n", cred->refresh_token) < 0) goto err;
+    mh_oauth_cred *creds = user_creds->creds;
+    size_t i;
+    for (i = 0; i < user_creds->len; i++) {
+        if (strcmp(creds[i].user, user) == 0) {
+            return &creds[i];
+        }
     }
-    if (cred->expires_at > 0) {
-        if (fprintf(fp, "expire: %ld\n", (long)cred->expires_at) < 0) goto err;
+    if (user_creds->alloc == user_creds->len) {
+        user_creds->alloc *= 2;
+        user_creds->creds = mh_xrealloc(user_creds->creds, user_creds->alloc);
     }
-    return TRUE;
+    creds = user_creds->creds+user_creds->len;
+    user_creds->len++;
+    creds->user = getcpy(user);
+    creds->access_token = creds->refresh_token = NULL;
+    creds->expires_at = 0;
+    return creds;
+}
 
-  err:
-    set_err(cred->ctx, MH_OAUTH_CRED_FILE);
-    return FALSE;
+static void
+free_user_creds(struct user_creds *user_creds)
+{
+    mh_oauth_cred *cred;
+    size_t i;
+    cred = user_creds->creds;
+    for (i = 0; i < user_creds->len; i++) {
+        free(cred[i].user);
+        free(cred[i].access_token);
+        free(cred[i].refresh_token);
+    }
+    free(user_creds->creds);
+    free(user_creds);
 }
 
 static boolean
-parse_cred(char **access, char **refresh, char **expire, FILE *fp,
-           mh_oauth_ctx *ctx)
+load_creds(struct user_creds **result, FILE *fp, mh_oauth_ctx *ctx)
 {
-    boolean result = FALSE;
+    boolean success = FALSE;
     char name[NAMESZ], value_buf[BUFSIZ];
     int state;
     m_getfld_state_t getfld_ctx = 0;
 
+    struct user_creds *user_creds = mh_xmalloc(sizeof *user_creds);
+    user_creds->alloc = 4;
+    user_creds->len = 0;
+    user_creds->creds = mh_xmalloc(user_creds->alloc * sizeof *user_creds->creds);
+
     for (;;) {
        int size = sizeof value_buf;
        switch (state = m_getfld(&getfld_ctx, name, value_buf, &size, fp)) {
         case FLD:
         case FLDPLUS: {
-            char **save;
-            if (strcmp(name, "access") == 0) {
-                save = access;
-            } else if (strcmp(name, "refresh") == 0) {
-                save = refresh;
-            } else if (strcmp(name, "expire") == 0) {
-                save = expire;
+            char **save, *expire;
+            time_t *expires_at = NULL;
+            if (strncmp(name, "access-", 7) == 0) {
+                const char *user = name + 7;
+                mh_oauth_cred *creds = find_or_alloc_user_creds(user_creds,
+                                                                user);
+                save = &creds->access_token;
+            } else if (strncmp(name, "refresh-", 8) == 0) {
+                const char *user = name + 8;
+                mh_oauth_cred *creds = find_or_alloc_user_creds(user_creds,
+                                                                user);
+                save = &creds->refresh_token;
+            } else if (strncmp(name, "expire-", 7) == 0) {
+                const char *user = name + 7;
+                mh_oauth_cred *creds = find_or_alloc_user_creds(user_creds,
+                                                                user);
+                expires_at = &creds->expires_at;
+                save = &expire;
             } else {
                 set_err_details(ctx, MH_OAUTH_CRED_FILE, "unexpected field");
                 break;
@@ -776,12 +825,23 @@ parse_cred(char **access, char **refresh, char **expire, FILE *fp,
                 *save = trimcpy(tmp);
                 free(tmp);
             }
+            if (expires_at != NULL) {
+                errno = 0;
+                *expires_at = strtol(expire, NULL, 10);
+                free(expire);
+                if (errno != 0) {
+                    set_err_details(ctx, MH_OAUTH_CRED_FILE,
+                                    "invalid expiration time");
+                    break;
+                }
+                expires_at = NULL;
+            }
             continue;
         }
 
         case BODY:
         case FILEEOF:
-            result = TRUE;
+            success = TRUE;
             break;
 
         default:
@@ -793,41 +853,114 @@ parse_cred(char **access, char **refresh, char **expire, FILE *fp,
        break;
     }
     m_getfld_state_destroy(&getfld_ctx);
-    return result;
+
+    if (success) {
+        *result = user_creds;
+    } else {
+        free_user_creds(user_creds);
+    }
+
+    return success;
+}
+
+static boolean
+save_user(FILE *fp, const char *user, const char *access, const char *refresh,
+          long expires_at)
+{
+    if (access != NULL) {
+        if (fprintf(fp, "access-%s: %s\n", user, access) < 0) return FALSE;
+    }
+    if (refresh != NULL) {
+        if (fprintf(fp, "refresh-%s: %s\n", user, refresh) < 0) return FALSE;
+    }
+    if (expires_at > 0) {
+        if (fprintf(fp, "expire-%s: %ld\n", user, (long)expires_at) < 0) {
+            return FALSE;
+        }
+    }
+    return TRUE;
+}
+
+boolean
+mh_oauth_cred_save(FILE *fp, mh_oauth_cred *cred, const char *user)
+{
+    struct user_creds *user_creds;
+    int fd = fileno(fp);
+    size_t i;
+
+    /* Load existing creds if any. */
+    if (!load_creds(&user_creds, fp, cred->ctx)) {
+        return FALSE;
+    }
+
+    if (fchmod(fd, S_IRUSR | S_IWUSR) < 0) goto err;
+    if (ftruncate(fd, 0) < 0) goto err;
+    if (fseek(fp, 0, SEEK_SET) < 0) goto err;
+
+    /* Write all creds except for this user. */
+    for (i = 0; i < user_creds->len; i++) {
+        mh_oauth_cred *c = &user_creds->creds[i];
+        if (strcmp(c->user, user) == 0) continue;
+        if (!save_user(fp, c->user, c->access_token, c->refresh_token,
+                       c->expires_at)) {
+            goto err;
+        }
+    }
+
+    /* Write updated creds for this user. */
+    if (!save_user(fp, user, cred->access_token, cred->refresh_token,
+                   cred->expires_at)) {
+        goto err;
+    }
+
+    free_user_creds(user_creds);
+
+    return TRUE;
+
+  err:
+    free_user_creds(user_creds);
+    set_err(cred->ctx, MH_OAUTH_CRED_FILE);
+    return FALSE;
 }
 
 mh_oauth_cred *
-mh_oauth_cred_load(FILE *fp, mh_oauth_ctx *ctx)
+mh_oauth_cred_load(FILE *fp, mh_oauth_ctx *ctx, const char *user)
 {
-    mh_oauth_cred *result;
-    time_t expires_at = 0;
-    char *access, *refresh, *expire;
-
-    access = refresh = expire = NULL;
-    if (!parse_cred(&access, &refresh, &expire, fp, ctx)) {
-        free(access);
-        free(refresh);
-        free(expire);
+    mh_oauth_cred *creds, *result = NULL;
+    struct user_creds *user_creds;
+    size_t i;
+
+    if (!load_creds(&user_creds, fp, ctx)) {
         return NULL;
     }
 
-    if (expire != NULL) {
-        errno = 0;
-        expires_at = strtol(expire, NULL, 10);
-        free(expire);
-        if (errno != 0) {
-            set_err_details(ctx, MH_OAUTH_CRED_FILE, "invalid expiration time");
-            free(access);
-            free(refresh);
-            return NULL;
+    /* Search user_creds for this user.  If we don't find it, return NULL.
+     * If we do, free fields of all structs except this one, moving this one to
+     * the first struct if necessary.  When we return it, it just looks like one
+     * struct to the caller, and the whole array is freed later. */
+    creds = user_creds->creds;
+    for (i = 0; i < user_creds->len; i++) {
+        if (strcmp(creds[i].user, user) == 0) {
+            result = creds;
+            if (i > 0) {
+                result->access_token = creds[i].access_token;
+                result->refresh_token = creds[i].refresh_token;
+                result->expires_at = creds[i].expires_at;
+            }
+        } else {
+            free(creds[i].access_token);
+            free(creds[i].refresh_token);
         }
+        free(creds[i].user);
+    }
+
+    if (result == NULL) {
+        set_err_details(ctx, MH_OAUTH_CRED_USER_NOT_FOUND, user);
+        return NULL;
     }
 
-    result = mh_xmalloc(sizeof *result);
     result->ctx = ctx;
-    result->access_token = access;
-    result->refresh_token = refresh;
-    result->expires_at = expires_at;
+    result->user = NULL;
 
     return result;
 }
@@ -1036,7 +1169,7 @@ parse_json(jsmntok_t **tokens, size_t *tokens_len,
            of the response body. */
         *tokens = mh_xrealloc(*tokens, *tokens_len * sizeof **tokens);
     }
-    if (r == 0) {
+    if (r <= 0) {
         return FALSE;
     }