Basic SSL connection going
[hcoop/domtool2.git] / src / openssl.sml
diff --git a/src/openssl.sml b/src/openssl.sml
new file mode 100644 (file)
index 0000000..c0a24d4
--- /dev/null
@@ -0,0 +1,191 @@
+(* HCoop Domtool (http://hcoop.sourceforge.net/)
+ * Copyright (c) 2006, Adam Chlipala
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU General Public License
+ * as published by the Free Software Foundation; either version 2
+ * of the License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+ *)
+
+(* OpenSSL *)
+
+(* The OpenSSL structure wraps the F_OpenSSL_SML_* C stubs.  Failures are
+ * reported by raising OpenSSL; most failure paths first call ssl_err, which
+ * prints the pending OpenSSL error to standard output. *)
+structure OpenSSL :> OPENSSL = struct
+
+(* One-time library setup, performed when the structure is elaborated. *)
+val () = (F_OpenSSL_SML_init.f' ();
+          F_OpenSSL_SML_load_error_strings.f' ();
+          F_OpenSSL_SML_load_BIO_strings.f' ())
+
+exception OpenSSL of string
+
+type context = (ST_ssl_ctx_st.tag, C_Int.rw) C_Int.su_obj C_Int.ptr'
+type bio = (ST_bio_st.tag, C_Int.rw) C_Int.su_obj C_Int.ptr'
+type listener = MLRep.Signed.int
+
+(* ssl_err s: print [s], then a line of the form
+ *   Reason: LIB:FUNC:REASON
+ * describing the error fetched with F_OpenSSL_SML_get_error.  Each part is
+ * printed only when the corresponding string pointer is non-null. *)
+fun ssl_err s =
+    let
+        val err = F_OpenSSL_SML_get_error.f ()
+
+        (* Print the C string [str] followed by [suffix], unless [str] is NULL. *)
+        fun printPart (str, suffix) =
+            if C.Ptr.isNull str then
+                ()
+            else
+                (print (ZString.toML str);
+                 print suffix)
+    in
+        print s;
+        print "\nReason: ";
+        printPart (F_OpenSSL_SML_lib_error_string.f err, ":");
+        printPart (F_OpenSSL_SML_func_error_string.f err, ":");
+        printPart (F_OpenSSL_SML_reason_error_string.f err, "");
+        print "\n"
+    end
+
+(* withZString (s, f): apply [f] to a freshly allocated zero-terminated C copy
+ * of [s], then free that copy.  The copy is also freed when [f] raises, and
+ * the exception is re-raised.  [f] must not keep the pointer after it returns. *)
+fun withZString (s, f) =
+    let
+        val z = ZString.dupML' s
+        val result = f z handle e => (C.free' z; raise e)
+    in
+        C.free' z;
+        result
+    end
+
+(* Receive buffer of Config.bufSize bytes, shared by every call to readOne.
+ * It is allocated once and never freed.
+ * NOTE(review): the sharing makes readOne unsafe to run concurrently. *)
+val readBuf : (C.uchar, C.rw) C.obj C.ptr' = C.alloc' C.S.uchar (Word.fromInt Config.bufSize)
+val bufSize = Int32.fromInt Config.bufSize
+
+(* readOne bio: read up to Config.bufSize bytes from [bio].  Returns NONE when
+ * the read returns 0, SOME of the bytes read when it returns a positive count,
+ * and raises OpenSSL "BIO_read failed" when it returns a negative value. *)
+fun readOne bio =
+    let
+        val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, bufSize)
+
+        (* Byte number [i] of readBuf, as an ML char. *)
+        fun byte i =
+            chr (Word32.toInt (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, i))))
+    in
+        if r = 0 then
+            NONE
+        else if r < 0 then
+            raise OpenSSL "BIO_read failed"
+        else
+            SOME (CharVector.tabulate (Int32.toInt r, byte))
+    end
+
+(* writeAll (bio, s): write all of [s] to [bio], retrying after short writes.
+ * Raises OpenSSL "BIO_write failed" if a write returns 0 or less before
+ * everything has been sent.  An empty [s] issues no write at all, because
+ * zero-length SSL writes are not well defined in OpenSSL.  The temporary C
+ * copy of [s] is always freed. *)
+fun writeAll (bio, s) =
+    let
+        (* Send [len] bytes starting at [buf].  After a partial write, advance
+         * [buf] past the bytes already sent and try again. *)
+        fun loop (buf, len) =
+            let
+                val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' buf, len)
+            in
+                if r = len then
+                    ()
+                else if r <= 0 then
+                    raise OpenSSL "BIO_write failed"
+                else
+                    loop (C.Ptr.|+! C.S.uchar (buf, Int32.toInt r), Int32.- (len, r))
+            end
+    in
+        if size s = 0 then
+            ()
+        else
+            (* withZString frees the pointer it allocated, never the advanced
+             * [buf], including when loop raises. *)
+            withZString (s, fn start => loop (start, Int32.fromInt (size s)))
+    end
+
+(* context (chain, key, root): create an SSL context from the SSLv23 method
+ * that uses the certificate chain file [chain], the private key file [key],
+ * and the trust store file [root].  On any failure the error is printed, the
+ * partly built context is freed, and OpenSSL is raised.  The temporary C
+ * copies of the three file names are freed in every case. *)
+fun context (chain, key, root) =
+    let
+        val ctx = F_OpenSSL_SML_CTX_new.f' (F_OpenSSL_SML_SSLv23_method.f' ())
+
+        (* Report [diag], release [ctx], and raise OpenSSL [msg]. *)
+        fun fail (diag, msg) =
+            (ssl_err diag;
+             F_OpenSSL_SML_CTX_free.f' ctx;
+             raise OpenSSL msg)
+    in
+        if C.Ptr.isNull' ctx then
+            (ssl_err "Error creating SSL context";
+             raise OpenSSL "Can't create SSL context")
+        else if withZString (chain, fn z =>
+                    F_OpenSSL_SML_use_certificate_chain_file.f' (ctx, z)) = 0 then
+            fail ("Error using certificate chain", "Can't load certificate chain")
+        else if withZString (key, fn z =>
+                    F_OpenSSL_SML_use_PrivateKey_file.f' (ctx, z)) = 0 then
+            fail ("Error using private key", "Can't load private key")
+        else if withZString (root, fn z =>
+                    F_OpenSSL_SML_load_verify_locations.f' (ctx, z, C.Ptr.null')) = 0 then
+            fail ("Error loading trust store", "Can't load trust store")
+        else
+            ctx
+    end
+
+(* connect (context, hostname): open an SSL connection through [context] to
+ * [hostname] and return the connected BIO.  On failure the error is printed,
+ * any BIO already created is freed, and OpenSSL is raised. *)
+fun connect (context, hostname) =
+    let
+        val bio = F_OpenSSL_SML_new_ssl_connect.f' context
+
+        (* Report [diag], release [bio], and raise OpenSSL [msg]. *)
+        fun fail (diag, msg) =
+            (ssl_err diag;
+             F_OpenSSL_SML_free_all.f' bio;
+             raise OpenSSL msg)
+    in
+        if C.Ptr.isNull' bio then
+            (ssl_err ("Error initializing connection to " ^ hostname);
+             raise OpenSSL "Can't initialize connection")
+        (* The C copy of [hostname] is freed right after the call.  This relies
+         * on BIO_set_conn_hostname copying the name, which OpenSSL's connect
+         * BIO does.  NOTE(review): confirm the C stub passes it straight through. *)
+        else if withZString (hostname, fn z =>
+                    F_OpenSSL_SML_set_conn_hostname.f' (bio, z)) = 0 then
+            fail ("Error setting hostname: " ^ hostname, "Can't set hostname")
+        else if F_OpenSSL_SML_do_connect.f' bio <= 0 then
+            fail ("Error connecting to " ^ hostname, "Can't connect")
+        else
+            bio
+    end
+
+(* Free [bio] via F_OpenSSL_SML_free_all (presumably BIO_free_all). *)
+fun close bio = F_OpenSSL_SML_free_all.f' bio
+
+(* Create a listener on [port] with queue size [qsize] through the
+ * F_OpenSSL_SML_tcp_listen stub (presumably a listening TCP socket). *)
+fun listen (port, qsize) = F_OpenSSL_SML_tcp_listen.f' (Int32.fromInt port, Int32.fromInt qsize)
+
+(* Pass the listener [sock] to the F_OpenSSL_SML_shutdown stub. *)
+fun shutdown sock = F_OpenSSL_SML_shutdown.f' sock
+
+(* accept (context, sock): accept one connection on the listener [sock] and
+ * run the server side of the SSL handshake on it.  Returns NONE if the accept
+ * stub reports a negative descriptor and SOME bio on success.  Raises OpenSSL
+ * if the BIO or the SSL object cannot be created, or if the handshake fails.
+ * NOTE(review): the returned value is the BIO made by F_OpenSSL_SML_new_socket
+ * and handed to the SSL object with the set_bio stub.  If that stub is a plain
+ * BIO_new_socket, reads and writes through it bypass TLS, and peerCN on it
+ * finds no SSL object.  Returning an SSL BIO that wraps [ssl] would need an
+ * extra C stub.  On the failure paths, the accepted descriptor and the SSL
+ * object are not released, because no stub for doing so is visible here. *)
+fun accept (context, sock) =
+    let
+        val sock' = F_OpenSSL_SML_accept.f' sock
+    in
+        if Int32.< (sock', Int32.fromInt 0) then
+            NONE
+        else
+            let
+                val bio = F_OpenSSL_SML_new_socket.f' sock'
+            in
+                if C.Ptr.isNull' bio then
+                    (ssl_err "Error initializing accepter";
+                     raise OpenSSL "Can't initialize accepter")
+                else
+                    let
+                        (* Created only once [bio] is known to be valid. *)
+                        val ssl = F_OpenSSL_SML_SSL_new.f' context
+
+                        (* Report [diag], release [bio], and raise OpenSSL [msg]. *)
+                        fun fail (diag, msg) =
+                            (ssl_err diag;
+                             F_OpenSSL_SML_free_all.f' bio;
+                             raise OpenSSL msg)
+                    in
+                        if C.Ptr.isNull' ssl then
+                            fail ("Error creating SSL object", "Can't create SSL object")
+                        else if (F_OpenSSL_SML_SSL_set_bio.f' (ssl, bio, bio);
+                                 F_OpenSSL_SML_SSL_accept.f' ssl) <= 0 then
+                            fail ("Error accepting connection", "Can't accept connection")
+                        else
+                            SOME bio
+                    end
+            end
+    end
+
+(* peerCN bio: return, as an ML string, the result of F_OpenSSL_SML_get_peer_name
+ * for the SSL object behind [bio].  Raises OpenSSL "Null SSL" if no SSL object
+ * is found, and OpenSSL "Null CN result" if the stub returns NULL.
+ * NOTE(review): ownership of the returned C string is not visible here.  If
+ * the stub allocates it, it is leaked. *)
+fun peerCN bio =
+    let
+        val ssl = F_OpenSSL_SML_get_ssl.f' bio
+    in
+        if C.Ptr.isNull' ssl then
+            raise OpenSSL "Null SSL"
+        else
+            let
+                val subj = F_OpenSSL_SML_get_peer_name.f' ssl
+            in
+                if C.Ptr.isNull' subj then
+                    raise OpenSSL "Null CN result"
+                else
+                    ZString.toML' subj
+            end
+    end
+
+end