Imported Upstream version 0.63.0
[hcoop/debian/courier-authlib.git] / authpgsqllib.c
index a01ab5f..b66b4f1 100644 (file)
 #define                SV_BEGIN_LEN            ((sizeof(SV_BEGIN_MARK))-1)
 #define                SV_END_LEN              ((sizeof(SV_END_MARK))-1)
 
-static const char rcsid[]="$Id: authpgsqllib.c,v 1.18 2006/10/28 19:22:52 mrsam Exp $";
+static const char rcsid[]="$Id: authpgsqllib.c,v 1.19 2008/12/18 12:08:25 mrsam Exp $";
 
 /* tom@minnesota.com */
 struct var_data {                      
        const char *name;
        const char *value;
        const size_t size;
-       size_t value_length;
        } ;
 
 /* tom@minnesota.com */
@@ -111,6 +110,40 @@ static PGresult *pgresult=0;
 
 static PGconn *pgconn=0;
 
+/*
+* Session variables can be set once for the whole session.
+*
+* When PGSQL_CHARACTER_SET is configured, request it as the client
+* encoding of pgconn.  A failure is logged and otherwise ignored.
+*/
+
+static void set_session_options(void)
+{
+       const char *character_set=read_env("PGSQL_CHARACTER_SET");
+
+       if (!character_set)
+               return;
+
+       /*
+       * PQsetClientEncoding() returns 0 on success, -1 on failure.  Test
+       * that instead of comparing names: pg_encoding_to_char() yields the
+       * canonical spelling (e.g. "UTF8"), so a configured alias such as
+       * "utf8" was set correctly yet reported as an error.
+       */
+       if (PQsetClientEncoding(pgconn, character_set) != 0)
+       {
+               err("Cannot set Postgresql character set \"%s\", working with \"%s\"\n",
+                   character_set,
+                   pg_encoding_to_char(PQclientEncoding(pgconn)));
+       }
+       else
+       {
+               DPRINTF("Install of a character set for Postgresql: %s", character_set);
+       }
+}
+
+
+
 /*
 static FILE *DEBUG=0;
 */
@@ -203,6 +228,7 @@ const       char *server_opt=0;
        fflush(DEBUG);
 */
 
+       set_session_options();
        return 0;
 
 }
@@ -218,20 +244,34 @@ void auth_pgsql_cleanup()
 
 static struct authpgsqluserinfo ui={0, 0, 0, 0, 0, 0, 0, 0};
 
-static void append_username(char *p, const char *username,
-                           const char *defdomain)
+static char *get_username_escaped(const char *username,
+                                 const char *defdomain)
 {
-       for (strcpy(p, username); *p; p++)
-               if (*p == '\'' || *p == '"' || *p == '\\' ||
-                   (int)(unsigned char)*p < ' ')
-                       *p=' '; /* No funny business */
-       if (strchr(username, '@') == 0 && defdomain && *defdomain)
-               strcat(strcpy(p, "@"), defdomain);
+       char *username_escaped;
+       int *error = NULL;
+
+       if (!defdomain)
+               defdomain="";
+
+        username_escaped=malloc(strlen(username)*2+2+strlen(defdomain));
+
+       if (!username_escaped)
+       {
+               perror("malloc");
+               return 0;
+       }
+
+       PQescapeStringConn(pgconn, username_escaped, username, strlen(username), error);
+
+       if (strchr(username, '@') == 0 && *defdomain)
+               strcat(strcat(username_escaped, "@"), defdomain);
+
+       return username_escaped;
 }
 
 /* tom@minnesota.com */
 static struct var_data *get_variable (const char *begin, size_t len,
-                                          struct var_data *vdt)
+                                     struct var_data *vdt)
 {
 struct var_data *vdp;
 
@@ -263,8 +303,6 @@ struct var_data *vdp;
                {
                        if (!vdp->value)
                                vdp->value = "";
-                       if (!vdp->value_length)         /* length cache */
-                               vdp->value_length = strlen (vdp->value);
                        return vdp;
                }
        
@@ -340,6 +378,8 @@ struct var_data     *v_ptr;
        q = source;
        while ( (p=strstr(q, SV_BEGIN_MARK)) )
        {
+               char *enc;
+
                e = strstr (p, SV_END_MARK);
                if (!e)
                {
@@ -374,10 +414,22 @@ struct var_data   *v_ptr;
                /* work on variable */
                v_ptr = get_variable (v_begin, v_size, vdt);
                if (!v_ptr) return -1;
-               
-               if ( (outfn (v_ptr->value, v_ptr->value_length, result)) == -1 )
+
+               enc=malloc(strlen(v_ptr->value)*2+1);
+
+               if (!enc)
                        return -1;
-               
+
+               PQescapeStringConn(pgconn, enc, v_ptr->value,
+                                  strlen(v_ptr->value), NULL);
+
+               if ( (outfn (enc, strlen(enc), result)) == -1 )
+               {
+                       free(enc);
+                       return -1;
+               }
+               free(enc);
+
                q = e + 1;
        }
 
@@ -392,7 +444,6 @@ struct var_data     *v_ptr;
 /* tom@minnesota.com */
 static char *parse_string (const char *source, struct var_data *vdt)
 {
-struct var_data *vdp   = NULL;
 char   *output_buf     = NULL,
        *pass_buf       = NULL;
 size_t buf_size        = 2;
@@ -405,11 +456,6 @@ size_t     buf_size        = 2;
                return NULL;
        }
 
-       /* zero var_data length cache - important! */
-       for (vdp=vdt; vdp->name; vdp++)
-               vdp->value_length = 0;
-
-
        /* phase 1 - count and validate string */
        if ( (parse_core (source, vdt, &ParsePlugin_counter, &buf_size)) != 0)
                return NULL;
@@ -434,108 +480,30 @@ size_t   buf_size        = 2;
        return output_buf;
 }
 
-/* tom@minnesota.com */
-static const char *get_localpart (const char *username)
+static char *get_localpart (const char *username)
 {
-size_t         lbuf    = 0;
-const char     *l_end, *p;
-char           *q;
-static char    localpart_buf[130];
-       
-       if (!username || *username == '\0')     return NULL;
-       
-       p = strchr(username,'@');
-       if (p)
-       {
-               if ((p-username) > 128)
-                       return NULL;
-               l_end = p;
-       }
-       else
-       {
-               if ((lbuf = strlen(username)) > 128)
-                       return NULL;
-               l_end = username + lbuf;
-       }
+       char *p=strdup(username);
+       char *q;
 
-       p=username;
-       q=localpart_buf;
-       
-       while (*p && p != l_end)
-               if (*p == '\"' || *p == '\\' ||
-                   *p == '\'' || (int)(unsigned char)*p < ' ')
-                       p++;
-               else
-                       *q++ = *p++;
+       if (!p)
+               return 0;
 
-       *q = '\0';
-       return localpart_buf;
-}
-
-/* tom@minnesota.com */
-static const char *get_domain (const char *username, const char *defdomain)
-{
-static char    domain_buf[260];
-const char     *p;
-char           *q;
-       
-       if (!username || *username == '\0')     return NULL;
-       p = strchr(username,'@');
-       
-       if (!p || *(p+1) == '\0')
-       {
-               if (defdomain && *defdomain)
-                       return defdomain;
-               else
-                       return NULL;
-       }
+       q=strchr(p, '@');
 
-       p++;
-       if ((strlen(p)) > 256)
-               return NULL;
-       
-       q = domain_buf;
-       while (*p)
-               if (*p == '\"' || *p == '\\' ||
-                   *p == '\'' || (int)(unsigned char)*p < ' ')
-                       p++;
-               else
-                       *q++ = *p++;
+       if (q)
+               *q=0;
 
-       *q = '\0';
-       return domain_buf;
+       return p;
 }
 
-/* tom@minnesota.com */
-
-static const char *validate_password (const char *password)
+static const char *get_domain (const char *username, const char *defdomain)
 {
-static char pass_buf[2][540]; /* Use two buffers, see parse_chpass_clause */
-static int next_pass=0;
-const char     *p;
-char           *q, *endq;
-       
-       if (!password || *password == '\0' || (strlen(password)) > 256)
-               return NULL;
-       
-       next_pass= 1-next_pass;
+       const char *p=strchr(username, '@');
 
-       p = password;
-       q = pass_buf[next_pass];
-       endq = q + sizeof pass_buf[next_pass];
-       
-       while (*p && q < endq)
-       {
-               if (*p == '\"' || *p == '\\' || *p == '\'')
-                       *q++ = '\\';
-               *q++ = *p++;
-       }
-       
-       if (q >= endq)
-               return NULL;
-       
-       *q = '\0';
-       return pass_buf[next_pass];
+       if (p)
+               return p+1;
+
+       return defdomain;
 }
 
 /* tom@minnesota.com */
@@ -543,23 +511,34 @@ static char *parse_select_clause (const char *clause, const char *username,
                                  const char *defdomain,
                                  const char *service)
 {
-static struct var_data vd[]={
-           {"local_part",      NULL,   sizeof("local_part"),   0},
-           {"domain",          NULL,   sizeof("domain"),       0},
-           {"service",         NULL,   sizeof("service"),      0},
-           {NULL,              NULL,   0,                      0}};
+       char *localpart, *ret;
+       static struct var_data vd[]={
+               {"local_part",  NULL,   sizeof("local_part")},
+               {"domain",      NULL,   sizeof("domain")},
+               {"service",     NULL,   sizeof("service")},
+               {NULL,          NULL,   0}};
 
        if (clause == NULL || *clause == '\0' ||
            !username || *username == '\0')
                return NULL;
        
-       vd[0].value     = get_localpart (username);
+       localpart=get_localpart(username);
+       if (!localpart)
+               return NULL;
+
+       vd[0].value     = localpart;
        vd[1].value     = get_domain (username, defdomain);
-       if (!vd[0].value || !vd[1].value)
+
+       if (!vd[1].value)
+       {
+               free(localpart);
                return NULL;
+       }
        vd[2].value     = service;
 
-       return (parse_string (clause, vd));
+       ret=parse_string (clause, vd);
+       free(localpart);
+       return ret;
 }
 
 /* tom@minnesota.com */
@@ -567,27 +546,38 @@ static char *parse_chpass_clause (const char *clause, const char *username,
                                  const char *defdomain, const char *newpass,
                                  const char *newpass_crypt)
 {
-static struct var_data vd[]={
-           {"local_part",      NULL,   sizeof("local_part"),           0},
-           {"domain",          NULL,   sizeof("domain"),               0},
-           {"newpass",         NULL,   sizeof("newpass"),              0},
-           {"newpass_crypt",   NULL,   sizeof("newpass_crypt"),        0},
-           {NULL,              NULL,   0,                              0}};
+       char *localpart, *ret;
+
+       static struct var_data vd[]={
+               {"local_part",  NULL,   sizeof("local_part")},
+               {"domain",      NULL,   sizeof("domain")},
+               {"newpass",     NULL,   sizeof("newpass")},
+               {"newpass_crypt", NULL, sizeof("newpass_crypt")},
+               {NULL,          NULL,   0}};
 
        if (clause == NULL || *clause == '\0'           ||
            !username || *username == '\0'              ||
            !newpass || *newpass == '\0'                ||
            !newpass_crypt || *newpass_crypt == '\0')   return NULL;
 
-       vd[0].value     = get_localpart (username);
+       localpart=get_localpart(username);
+       if (!localpart)
+               return NULL;
+
+       vd[0].value     = localpart;
        vd[1].value     = get_domain (username, defdomain);
-       vd[2].value     = validate_password (newpass);
-       vd[3].value     = validate_password (newpass_crypt);
+       vd[2].value     = newpass;
+       vd[3].value     = newpass_crypt;
        
-       if (!vd[0].value || !vd[1].value ||
-           !vd[2].value || !vd[3].value)       return NULL;
+       if (!vd[1].value || !vd[2].value || !vd[3].value)
+       {
+               free(localpart);
+               return NULL;
+       }
 
-       return (parse_string (clause, vd));
+       ret=parse_string (clause, vd);
+       free(localpart);
+       return ret;
 }
 
 static void initui()
@@ -615,11 +605,20 @@ static void initui()
 struct authpgsqluserinfo *auth_pgsql_getuserinfo(const char *username,
                                                 const char *service)
 {
-const char *defdomain, *select_clause;
-char   *querybuf, *p;
+       const char *defdomain, *select_clause;
+       char    *querybuf;
+       size_t query_size;
+       char dummy_buf[1];
+
+#define SELECT_QUERY "SELECT %s, %s, %s, %s, %s, %s, %s, %s, %s, %s FROM %s WHERE %s = '%s' %s%s%s", \
+               login_field, crypt_field, clear_field, \
+               uid_field, gid_field, home_field, maildir_field, \
+               quota_field, \
+               name_field, \
+               options_field, \
+               user_table, login_field, username_escaped, \
+               where_pfix, where_clause, where_sfix
 
-static const char query[]=
-       "SELECT %s, %s, %s, %s, %s, %s, %s, %s, %s, %s FROM %s WHERE %s = '";
 
        if (do_connect())       return (0);
 
@@ -648,6 +647,9 @@ static const char query[]=
                                *options_field,
                                *where_clause;
 
+               const char *where_pfix, *where_sfix;
+               char *username_escaped;
+
                user_table=read_env("PGSQL_USER_TABLE");
 
                if (!user_table)
@@ -697,43 +699,32 @@ static const char query[]=
                where_clause=read_env("PGSQL_WHERE_CLAUSE");
                if (!where_clause) where_clause = "";
 
-               querybuf=malloc(sizeof(query) + 100
-                               + 2 * strlen(login_field)
-                               + strlen(crypt_field)
-                               + strlen(clear_field)
-                               + strlen(uid_field) + strlen(gid_field)
-                               + strlen(home_field)
-                               + strlen(maildir_field)
-                               + strlen(quota_field)
-                               + strlen(name_field)
-                               + strlen(options_field)
-                               + strlen(user_table)
-                               + strlen(username)
-                               + strlen(defdomain)
-                               + strlen(where_clause));
+               where_pfix=where_sfix="";
 
-               if (!querybuf)
+               if (strcmp(where_clause, ""))
                {
-                       perror("malloc");
-                       return (0);
+                       where_pfix=" AND (";
+                       where_sfix=")";
                }
 
-               sprintf(querybuf, query, login_field, crypt_field, clear_field,
-                       uid_field, gid_field, home_field, maildir_field,
-                       quota_field,
-                       name_field,
-                       options_field,
-                       user_table, login_field);
-               p=querybuf+strlen(querybuf);
+               username_escaped=get_username_escaped(username, defdomain);
 
-               append_username(p, username, defdomain);
-               strcat(p, "'");
-               
-               if (strcmp(where_clause, "")) {
-                       strcat(p, " AND (");
-                       strcat(p, where_clause);
-                       strcat(p, ")");
+               if (!username_escaped)
+                       return 0;
+
+               query_size=snprintf(dummy_buf, 1, SELECT_QUERY);
+
+               querybuf=malloc(query_size+1);
+
+               if (!querybuf)
+               {
+                       free(username_escaped);
+                       perror("malloc");
+                       return 0;
                }
+
+               snprintf(querybuf, query_size+1, SELECT_QUERY);
+               free(username_escaped);
        }
        else
        {
@@ -849,12 +840,17 @@ int auth_pgsql_setpass(const char *user, const char *pass,
                       const char *oldpass)
 {
        char *newpass_crypt;
-       const char *p;
-       int l;
        char *sql_buf;
-       const char *comma;
+       size_t sql_buf_size;
+       char dummy_buf[1];
        int rc=0;
 
+       char *clear_escaped;
+       char *crypt_escaped;
+       int  *error = NULL;
+
+       char *username_escaped;
+
        const char *clear_field=NULL;
        const char *crypt_field=NULL;
        const char *defdomain=NULL;
@@ -870,17 +866,30 @@ int auth_pgsql_setpass(const char *user, const char *pass,
        if (!(newpass_crypt=authcryptpasswd(pass, oldpass)))
                return (-1);
 
-       for (l=0, p=pass; *p; p++)
-       {
-               if ((int)(unsigned char)*p < ' ')
-               {
-                       free(newpass_crypt);
-                       return (-1);
-               }
-               if (*p == '"' || *p == '\\')
-                       ++l;
-               ++l;
-       }
+       clear_escaped=malloc(strlen(pass)*2+1);
+
+        if (!clear_escaped)
+        {
+                perror("malloc");
+                free(newpass_crypt);
+                return -1;
+        }
+
+        crypt_escaped=malloc(strlen(newpass_crypt)*2+1);
+
+        if (!crypt_escaped)
+        {
+                perror("malloc");
+                free(clear_escaped);
+                free(newpass_crypt);
+                return -1;
+        }
+
+        PQescapeStringConn(pgconn, clear_escaped, pass, strlen(pass), error);
+        PQescapeStringConn(pgconn, crypt_escaped,
+                                 newpass_crypt, strlen(newpass_crypt), error);
+
+
 
        /* tom@minnesota.com */
        chpass_clause=read_env("PGSQL_CHPASS_CLAUSE");
@@ -893,13 +902,60 @@ int auth_pgsql_setpass(const char *user, const char *pass,
                crypt_field=read_env("PGSQL_CRYPT_PWFIELD");
                clear_field=read_env("PGSQL_CLEAR_PWFIELD");
                where_clause=read_env("PGSQL_WHERE_CLAUSE");
-               sql_buf=malloc(strlen(crypt_field ? crypt_field:"")
-                              + strlen(clear_field ? clear_field:"")
-                              + strlen(defdomain ? defdomain:"")
-                              + strlen(login_field) + l + strlen(newpass_crypt)
-                              + strlen(user_table)
-                              + strlen(where_clause ? where_clause:"")
-                              + 200);
+
+               username_escaped=get_username_escaped(user, defdomain);
+
+               if (!username_escaped)
+               {
+                       /* don't leak the buffers allocated earlier */
+                       free(clear_escaped);
+                       free(crypt_escaped);
+                       free(newpass_crypt);
+                       return -1;
+               }
+
+               if (!where_clause)
+                       where_clause="";
+
+               if (!crypt_field)
+                       crypt_field="";
+
+               if (!clear_field)
+                       clear_field="";
+
+               /*
+               ** Format string plus its arguments.  Expanded twice below:
+               ** first into dummy_buf to measure, then into sql_buf.
+               */
+#define DEFAULT_SETPASS_UPDATE \
+               "UPDATE %s SET %s%s%s%s %s %s%s%s%s WHERE %s='%s' %s%s%s", \
+                       user_table,                                     \
+                       *clear_field ? clear_field:"",                  \
+                       *clear_field ? "='":"",                         \
+                       *clear_field ? clear_escaped:"",                \
+                       *clear_field ? "'":"",                          \
+                                                                       \
+                       *clear_field && *crypt_field ? ",":"",          \
+                                                                       \
+                       *crypt_field ? crypt_field:"",                  \
+                       *crypt_field ? "='":"",                         \
+                       *crypt_field ? crypt_escaped:"",                \
+                       *crypt_field ? "'":"",                          \
+                                                                       \
+                       login_field, username_escaped,                  \
+                       *where_clause ? " AND (":"", where_clause,      \
+                       *where_clause ? ")":""
+
+
+               sql_buf_size=snprintf(dummy_buf, 1, DEFAULT_SETPASS_UPDATE);
+
+               sql_buf=malloc(sql_buf_size+1);
+
+               if (sql_buf)
+                       snprintf(sql_buf, sql_buf_size+1,
+                                DEFAULT_SETPASS_UPDATE);
+
+               free(username_escaped);
        }
        else
        {
@@ -912,62 +958,12 @@ int auth_pgsql_setpass(const char *user, const char *pass,
 
        if (!sql_buf)
        {
+               /* both escaped copies were allocated before this point */
+               free(clear_escaped);
+               free(crypt_escaped);
                free(newpass_crypt);
                return (-1);
        }
-
-       if (!chpass_clause) /* tom@minnesota.com */
-       {
-               sprintf(sql_buf, "UPDATE %s SET", user_table);
-
-               comma="";
-
-               if (clear_field && *clear_field)
-               {
-                       char *q;
-
-                       strcat(strcat(strcat(sql_buf, " "), clear_field),
-                              "='");
-
-                       q=sql_buf+strlen(sql_buf);
-                       while (*pass)
-                       {
-                               if (*pass == '"' || *pass == '\\')
-                                       *q++= '\\';
-                               *q++ = *pass++;
-                       }
-                       strcpy(q, "'");
-                       comma=", ";
-               }
-
-               if (crypt_field && *crypt_field)
-               {
-                       strcat(strcat(strcat(strcat(strcat(strcat(sql_buf, comma),
-                                                          " "),
-                                                   crypt_field),
-                                            "='"),
-                                     newpass_crypt),
-                              "'");
-               }
-               free(newpass_crypt);
-
-               strcat(strcat(strcat(sql_buf, " WHERE "),
-                             login_field),
-                      "='");
-
-               append_username(sql_buf+strlen(sql_buf), user, defdomain);
-
-               strcat(sql_buf, "'");
-
-               if (where_clause && *where_clause)
-               {
-                       strcat(sql_buf, " AND (");
-                       strcat(sql_buf, where_clause);
-                       strcat(sql_buf, ")");
-               }
-
-       } /* end of: if (!chpass_clause) */
-
        if (courier_authdebug_login_level >= 2)
        {
                DPRINTF("setpass SQL: %s", sql_buf);
@@ -980,6 +974,9 @@ int auth_pgsql_setpass(const char *user, const char *pass,
                auth_pgsql_cleanup();
        }
        PQclear(pgresult);
+       free(clear_escaped);
+       free(crypt_escaped);
+       free(newpass_crypt);
        free(sql_buf);
        return (rc);
 }
@@ -994,11 +991,9 @@ void auth_pgsql_enumerate( void(*cb_func)(const char *name,
                           void *void_arg)
 {
        const char *select_clause, *defdomain;
-       char    *querybuf, *p;
+       char    *querybuf;
 
-static const char query[]=
-       "SELECT %s, %s, %s, %s, %s, %s FROM %s WHERE 1=1";
-int i,n;
+       int i,n;
 
        if (do_connect())       return;
 
@@ -1019,6 +1014,8 @@ int i,n;
                                *maildir_field,
                                *options_field,
                                *where_clause;
+               char dummy_buf[1];
+               size_t query_len;
 
                user_table=read_env("PGSQL_USER_TABLE");
 
@@ -1050,14 +1047,18 @@ int i,n;
                where_clause=read_env("PGSQL_WHERE_CLAUSE");
                if (!where_clause) where_clause = "";
 
-               querybuf=malloc(sizeof(query) + 100
-                               + strlen(login_field)
-                               + strlen(uid_field) + strlen(gid_field)
-                               + strlen(home_field)
-                               + strlen(maildir_field)
-                               + strlen(options_field)
-                               + strlen(user_table)
-                               + strlen(where_clause));
+#define DEFAULT_ENUMERATE_QUERY \
+               "SELECT %s, %s, %s, %s, %s, %s FROM %s %s%s",\
+                       login_field, uid_field, gid_field,              \
+                       home_field, maildir_field,                      \
+                       options_field, user_table,                      \
+                       *where_clause ? " WHERE ":"",                   \
+                       where_clause
+
+
+               query_len=snprintf(dummy_buf, 1, DEFAULT_ENUMERATE_QUERY);
+
+               querybuf=malloc(query_len+1);
 
                if (!querybuf)
                {
@@ -1065,16 +1066,7 @@ int i,n;
                        return;
                }
 
-               sprintf(querybuf, query, login_field, 
-                       uid_field, gid_field, home_field, maildir_field,
-                       options_field, user_table);
-               p=querybuf+strlen(querybuf);
-               
-               if (strcmp(where_clause, "")) {
-                       strcat(p, " AND (");
-                       strcat(p, where_clause);
-                       strcat(p, ")");
-               }
+               snprintf(querybuf, query_len+1, DEFAULT_ENUMERATE_QUERY);
        }
        else
        {