Adding domain description
[hcoop/domtool2.git] / src / openssl.sml
index c0a24d4..aad3925 100644 (file)
@@ -28,7 +28,7 @@ exception OpenSSL of string
 
+(* Handle types: context and bio are C pointers to the ssl_ctx_st and bio_st
+ * structs respectively; a listener is represented by a bio as well. *)
 type context = (ST_ssl_ctx_st.tag, C_Int.rw) C_Int.su_obj C_Int.ptr'
 type bio = (ST_bio_st.tag, C_Int.rw) C_Int.su_obj C_Int.ptr'
-type listener = MLRep.Signed.int
+type listener = bio
 
 fun ssl_err s =
     let
@@ -59,65 +59,207 @@ fun ssl_err s =
 
+(* Shared scratch buffer of Config.bufSize uchars, allocated once by this
+ * binding.  The read and write helpers in this file stage bytes through it,
+ * so its contents are only meaningful until the next helper call. *)
 val readBuf : (C.uchar, C.rw) C.obj C.ptr' = C.alloc' C.S.uchar (Word.fromInt Config.bufSize)
 val bufSize = Int32.fromInt Config.bufSize
+(* Byte counts (as Int32) passed to the read/write FFI calls. *)
+val one = Int32.fromInt 1
+val four = Int32.fromInt 4
 
-fun readOne bio =
+(* Shift amounts (in bits) used to place bytes 1, 2 and 3 of a 32-bit word. *)
+val eight = Word.fromInt 8
+val sixteen = Word.fromInt 16
+val twentyfour = Word.fromInt 24
+
+(* Mask selecting the low 8 bits (255) of a Word32. *)
+val mask1 = Word32.fromInt 255
+
+(* Read a single character from bio through the shared readBuf.
+ * Result: NONE when the read returns 0, SOME c for a successful read.
+ * A negative return is reported via ssl_err and raises OpenSSL. *)
+fun readChar bio =
+    let
+       val nread = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, one)
+
+       (* Decode byte 0 of readBuf as an ML char. *)
+       fun firstByte () =
+           chr (Compat.Char.toInt (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 0))))
+    in
+       if nread < 0 then
+           (ssl_err "BIO_read";
+            raise OpenSSL "BIO_read failed")
+       else if nread = 0 then
+           NONE
+       else
+           SOME (firstByte ())
+    end
+
+(* Widen a uchar value (as produced by C.Get.uchar') to a Word32. *)
+fun charToWord c = Word32.fromLargeWord (Compat.Char.toLargeWord c)
+
+(* Read a 32-bit integer sent as four bytes, least significant byte first
+ * (the layout writeInt produces).
+ *
+ * Returns NONE if the stream ends before all four bytes have arrived, and
+ * SOME n otherwise.  A negative read result is reported via ssl_err and
+ * raises OpenSSL.
+ *
+ * A single read may deliver fewer than four bytes, so we keep reading into
+ * the remainder of readBuf until the integer is complete; decoding after a
+ * short read would mix in stale bytes left over from an earlier call. *)
+fun readInt bio =
+    let
+       (* Fill `needed` bytes starting at dst; true on success, false on EOF. *)
+       fun fill (dst, needed) =
+           let
+               val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' dst, needed)
+           in
+               if r = 0 then
+                   false
+               else if r < 0 then
+                   (ssl_err "BIO_read";
+                    raise OpenSSL "BIO_read failed")
+               else if r = needed then
+                   true
+               else
+                   fill (C.Ptr.|+! C.S.uchar (dst, Int32.toInt r), needed - r)
+           end
+
+       (* Byte i of readBuf, shifted into its position in the result word. *)
+       fun byteAt (i, shift) =
+           Word32.<< (charToWord (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, i))),
+                      shift)
+    in
+       if fill (readBuf, four) then
+           SOME (Word32.toInt
+                     (Word32.+
+                      (byteAt (0, 0w0),
+                       Word32.+
+                       (byteAt (1, eight),
+                        Word32.+
+                        (byteAt (2, sixteen),
+                         byteAt (3, twentyfour))))))
+       else
+           NONE
+    end
+
+(* Read exactly len bytes from bio and return them as a string.
+ *
+ * Returns SOME "" for len = 0, NONE if a read returns 0 before len bytes
+ * have arrived, and SOME s otherwise.  A negative read result is reported
+ * via ssl_err and raises OpenSSL.
+ *
+ * Payloads that fit use the shared readBuf; larger ones get a temporary
+ * buffer that is freed exactly once on every exit path. *)
+fun readLen (bio, len) =
+    if len = 0 then
+       SOME ""
+    else
+       let
+           val buf =
+               if len > Config.bufSize then
+                   C.alloc' C.S.uchar (Word.fromInt len)
+               else
+                   readBuf
+
+           (* Free buf only when it was allocated here; readBuf is shared. *)
+           fun cleanup () =
+               if len > Config.bufSize then
+                   C.free' buf
+               else
+                   ()
+
+           (* buf' is the write cursor inside buf; needed is the byte count
+            * still missing.  Each read targets the cursor and asks only for
+            * what is missing, so partial reads accumulate and never run past
+            * the end of buf. *)
+           fun loop (buf', needed) =
+               let
+                   val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' buf', needed)
+               in
+                   if r = 0 then
+                       (* The trailing `before cleanup ()` releases buf. *)
+                       NONE
+                   else if r < 0 then
+                       (* raise skips the trailing cleanup, so free here. *)
+                       (cleanup ();
+                        ssl_err "BIO_read";
+                        raise OpenSSL "BIO_read failed")
+                   else if r = needed then
+                       (* Copy the whole payload, starting at the buffer head. *)
+                       SOME (CharVector.tabulate (len,
+                                               fn i => chr (Compat.Char.toInt (C.Get.uchar'
+                                                                                   (C.Ptr.sub' C.S.uchar (buf, i))))))
+                   else
+                       loop (C.Ptr.|+! C.S.uchar (buf', Int32.toInt r), needed - r)
+               end
+       in
+           loop (buf, Int32.fromInt len)
+           before cleanup ()
+       end
+
+(* Read whatever a single read call delivers, up to bufSize bytes, into the
+ * shared readBuf and return it as a string.
+ * Returns NONE when the read returns 0; a negative return is reported via
+ * ssl_err and raises OpenSSL "BIO_read failed". *)
+fun readChunk bio =
     let
        val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, bufSize)
     in
        if r = 0 then
            NONE
        else if r < 0 then
-           raise OpenSSL "BIO_read failed"
+           (ssl_err "BIO_read";
+            raise OpenSSL "BIO_read failed")
        else
            SOME (CharVector.tabulate (Int32.toInt r,
-                                   fn i => chr (Word32.toInt (C.Get.uchar'
-                                                                  (C.Ptr.sub' C.S.uchar (readBuf, i))))))
+                                      fn i => chr (Compat.Char.toInt (C.Get.uchar'
+                                                                          (C.Ptr.sub' C.S.uchar (readBuf, i))))))
     end
 
-fun writeAll (bio, s) =
+(* Read a length-prefixed string: an integer length via readInt, then that
+ * many bytes via readLen.  NONE from either step yields NONE. *)
+fun readString bio =
+    Option.mapPartial (fn len => readLen (bio, len)) (readInt bio)
+
+(* Write the single character ch to bio, staging it in byte 0 of the shared
+ * readBuf.  A write result of 0 is retried immediately, with no delay or
+ * retry limit; a negative result is reported via ssl_err and raises
+ * OpenSSL "BIO_write". *)
+fun writeChar (bio, ch) =
     let
-       val buf = ZString.dupML' s
+       val _ = C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 0),
+                             Compat.Char.fromInt (ord ch))
 
-       fun loop (buf, len) =
+       fun trier () =
            let
-               val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' buf, len)
+               val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' readBuf, one)
            in
-               if r = len then
-                   ()
-               else if r <= 0 then
-                   (C.free' buf;
-                    raise OpenSSL "BIO_write failed")
+               if r = 0 then
+                   trier ()
+               else if r < 0 then
+                   (ssl_err "BIO_write";
+                    raise OpenSSL "BIO_write")
                else
-                   loop (C.Ptr.|+! C.S.uchar (buf, Int32.toInt r), Int32.- (len, r))
+                   ()
            end
     in
-       loop (buf, Int32.fromInt (size s));
-       C.free' buf
+       trier ()
     end
 
-fun context (chain, key, root) =
+(* Narrow a Word32 to the uchar representation expected by C.Set.uchar'. *)
+fun wordToChar w = Compat.Char.fromLargeWord (Word32.toLargeWord w)
+
+(* Send n as four bytes, least significant byte first, staged in the first
+ * four bytes of the shared readBuf.  Partial writes (including a result of
+ * 0) are continued until all four bytes are out; a negative write result
+ * is reported via ssl_err and raises OpenSSL "BIO_write". *)
+fun writeInt (bio, n) =
+    let
+       val w = Word32.fromInt n
+
+       (* Store bits [shift, shift+8) of w into readBuf[idx]. *)
+       fun store (idx, shift) =
+           C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, idx),
+                         wordToChar (Word32.andb (Word32.>> (w, shift), mask1)))
+
+       val _ = (store (0, 0w0);
+                store (1, eight);
+                store (2, sixteen);
+                store (3, twentyfour))
+
+       (* Push `left` bytes starting at cursor, resuming after short writes. *)
+       fun send (cursor, left) =
+           let
+               val sent = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' cursor, left)
+           in
+               if sent < 0 then
+                   (ssl_err "BIO_write";
+                    raise OpenSSL "BIO_write")
+               else if sent = left then
+                   ()
+               else
+                   send (C.Ptr.|+! C.S.uchar (cursor, Int32.toInt sent), left - sent)
+           end
+    in
+       send (readBuf, 4)
+    end
+
+(* Send the bytes of s with no length prefix; the empty string sends nothing.
+ * The string is copied to a temporary C string, handed to the puts wrapper,
+ * and the copy is freed whether or not the call succeeds.  A result <= 0 is
+ * reported via ssl_err and raises OpenSSL "BIO_puts".
+ * NOTE(review): a puts-style call presumably stops at the first NUL byte, so
+ * strings with embedded NULs may be sent short - confirm against the C
+ * wrapper. *)
+fun writeString' (bio, s) =
+    case s of
+       "" => ()
+      | _ =>
+       let
+           val zs = ZString.dupML' s
+           val rc = F_OpenSSL_SML_puts.f' (bio, zs)
+       in
+           C.free' zs;
+           if rc <= 0 then
+               (ssl_err "BIO_puts";
+                raise OpenSSL "BIO_puts")
+           else
+               ()
+       end
+
+(* Send s in the length-prefixed form readString expects: its size via
+ * writeInt, followed by its bytes via writeString'. *)
+fun writeString (bio, s) =
+    let
+       val _ = writeInt (bio, size s)
+    in
+       writeString' (bio, s)
+    end
+
+fun context printErr (chain, key, root) =
     let
        val context = F_OpenSSL_SML_CTX_new.f' (F_OpenSSL_SML_SSLv23_method.f' ())
     in
        if C.Ptr.isNull' context then
-           (ssl_err "Error creating SSL context";
+           (if printErr then ssl_err "Error creating SSL context" else ();
             raise OpenSSL "Can't create SSL context")
        else if F_OpenSSL_SML_use_certificate_chain_file.f' (context,
                                                             ZString.dupML' chain)
                = 0 then
-           (ssl_err "Error using certificate chain";
+           (if printErr then ssl_err "Error using certificate chain" else ();
             F_OpenSSL_SML_CTX_free.f' context;
             raise OpenSSL "Can't load certificate chain")
        else if F_OpenSSL_SML_use_PrivateKey_file.f' (context,
                                                      ZString.dupML' key)
                = 0 then
-           (ssl_err "Error using private key";
+           (if printErr then ssl_err "Error using private key" else ();
             F_OpenSSL_SML_CTX_free.f' context;
             raise OpenSSL "Can't load private key")
        else if F_OpenSSL_SML_load_verify_locations.f' (context,
                                                        ZString.dupML' root,
                                                        C.Ptr.null') = 0 then
-           (ssl_err "Error loading trust store";
+           (if printErr then ssl_err "Error loading trust store" else ();
             F_OpenSSL_SML_CTX_free.f' context;
             raise OpenSSL "Can't load trust store")
        else
@@ -146,32 +288,41 @@ fun connect (context, hostname) =
 
+(* Release a bio by passing it to F_OpenSSL_SML_free_all. *)
 fun close bio = F_OpenSSL_SML_free_all.f' bio
 
-fun listen (port, qsize) = F_OpenSSL_SML_tcp_listen.f' (Int32.fromInt port, Int32.fromInt qsize)
-fun shutdown sock = F_OpenSSL_SML_shutdown.f' sock
-
-fun accept (context, sock) =
-     let
-        val sock' = F_OpenSSL_SML_accept.f' sock
-     in
-        if Int32.< (sock', Int32.fromInt 0) then
-            NONE
-        else let
-                val bio = F_OpenSSL_SML_new_socket.f' sock'
-                val ssl = F_OpenSSL_SML_SSL_new.f' context
-            in
-                if C.Ptr.isNull' bio then
-                    (ssl_err "Error initializating accepter";
-                     F_OpenSSL_SML_free_all.f' bio;
-                     raise OpenSSL "Can't initialize accepter")
-                else if (F_OpenSSL_SML_SSL_set_bio.f' (ssl, bio, bio);
-                         F_OpenSSL_SML_SSL_accept.f' ssl) <= 0 then
-                    (ssl_err "Error accepting connection";
-                     F_OpenSSL_SML_free_all.f' bio;
-                     raise OpenSSL "Can't accept connection")
-                else
-                    SOME bio
-            end
-     end
+(* Create a listener for the given port using context.
+ * The port number is passed to the C side as a decimal string, which is
+ * freed as soon as the accept bio has been created.  The first do_accept
+ * call is made here; accept makes the later ones.
+ * Raises OpenSSL (after logging via ssl_err) if the listener is null or the
+ * initial do_accept fails; in the latter case the listener is closed first. *)
+fun listen (context, port) =
+    let
+       val portStr = ZString.dupML' (Int.toString port)
+       val acceptBio = F_OpenSSL_SML_new_accept.f' (context, portStr)
+       val _ = C.free' portStr
+    in
+       if not (C.Ptr.isNull' acceptBio) then
+           if F_OpenSSL_SML_do_accept.f' acceptBio > 0 then
+               acceptBio
+           else
+               (ssl_err "Error initializing listener";
+                close acceptBio;
+                raise OpenSSL "Can't initialize listener")
+       else
+           (ssl_err "Null listener";
+            raise OpenSSL "Null listener")
+    end
+
+(* Shutting down a listener is the same operation as closing a bio. *)
+fun shutdown listener = close listener
+
+(* Wait for the next connection on listener and complete its handshake.
+ *
+ * Returns NONE if do_accept reports failure (result <= 0), SOME bio for a
+ * connection whose handshake succeeded.  Raises OpenSSL (after logging via
+ * ssl_err) if the popped bio is null or the handshake fails.
+ *
+ * On handshake failure the accepted bio is closed before raising, so a
+ * rejected client does not leak its connection; this matches listen, which
+ * closes its bio on failure. *)
+fun accept listener =
+    if F_OpenSSL_SML_do_accept.f' listener <= 0 then
+       NONE
+    else
+       let
+           val bio = F_OpenSSL_SML_pop.f' listener
+       in
+           if C.Ptr.isNull' bio then
+               (ssl_err "Null accepted";
+                raise OpenSSL "Null accepted")
+           else if F_OpenSSL_SML_do_handshake.f' bio <= 0 then
+               (ssl_err "Handshake failed";
+                close bio;
+                raise OpenSSL "Handshake failed")
+           else
+               SOME bio
+       end
 
 fun peerCN bio =
     let