X-Git-Url: https://git.hcoop.net/hcoop/domtool2.git/blobdiff_plain/605347124bd39d347058bc3bd5356c184f654b1d..36e42cb86393a7b9e333ecd7edfbdd16c7d9a1ac:/src/openssl.sml diff --git a/src/openssl.sml b/src/openssl.sml index 627b405..7a062f1 100644 --- a/src/openssl.sml +++ b/src/openssl.sml @@ -59,42 +59,174 @@ fun ssl_err s = val readBuf : (C.uchar, C.rw) C.obj C.ptr' = C.alloc' C.S.uchar (Word.fromInt Config.bufSize) val bufSize = Int32.fromInt Config.bufSize +val one = Int32.fromInt 1 +val four = Int32.fromInt 4 -fun readOne bio = +val eight = Word.fromInt 8 +val sixteen = Word.fromInt 16 +val twentyfour = Word.fromInt 24 + +val mask1 = Word32.fromInt 255 + +fun readChar bio = + let + val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, one) + in + if r = 0 then + NONE + else if r < 0 then + (ssl_err "BIO_read"; + raise OpenSSL "BIO_read failed") + else + SOME (chr (Word32.toInt (C.Get.uchar' + (C.Ptr.sub' C.S.uchar (readBuf, 0))))) + end + +fun readInt bio = + let + val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, four) + in + if r = 0 then + NONE + else if r < 0 then + (ssl_err "BIO_read"; + raise OpenSSL "BIO_read failed") + else + SOME (Word32.toInt + (Word32.+ + (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 0)), + Word32.+ + (Word32.<< (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 1)), + eight), + Word32.+ + (Word32.<< (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 2)), + sixteen), + Word32.<< (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 3)), + twentyfour)))))) + end + +fun readLen (bio, len) = + let + val buf = + if len > Config.bufSize then + C.alloc' C.S.uchar (Word.fromInt len) + else + readBuf + + fun cleanup () = + if len > Config.bufSize then + C.free' buf + else + () + + fun loop (buf', needed) = + let + val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' buf, Int32.fromInt len) + in + if r = 0 then + (cleanup (); NONE) + else if r < 0 then + (cleanup (); + ssl_err "BIO_read"; + raise OpenSSL "BIO_read failed") + else if r = needed then + SOME (CharVector.tabulate (Int32.toInt needed, + fn i => chr (Word32.toInt (C.Get.uchar' + (C.Ptr.sub' C.S.uchar (buf, i)))))) + else + loop (C.Ptr.|+! C.S.uchar (buf', Int32.toInt r), needed - r) + end + in + loop (buf, Int32.fromInt len) + before cleanup () + end + +fun readChunk bio = let val r = F_OpenSSL_SML_read.f' (bio, C.Ptr.inject' readBuf, bufSize) in if r = 0 then NONE else if r < 0 then - raise OpenSSL "BIO_read failed" + (ssl_err "BIO_read"; + raise OpenSSL "BIO_read failed") else SOME (CharVector.tabulate (Int32.toInt r, fn i => chr (Word32.toInt (C.Get.uchar' (C.Ptr.sub' C.S.uchar (readBuf, i)))))) end -fun writeAll (bio, s) = +fun readString bio = + case readInt bio of + NONE => NONE + | SOME len => readLen (bio, len) + +fun writeChar (bio, ch) = let - val buf = ZString.dupML' s + val _ = C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 0), + Word32.fromInt (ord ch)) - fun loop (buf, len) = + fun trier () = let - val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' buf, len) + val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' readBuf, one) in - if r = len then + if r = 0 then + trier () + else if r < 0 then + (ssl_err "BIO_write"; + raise OpenSSL "BIO_write") + else + () + end + in + trier () + end + +fun writeInt (bio, n) = + let + val w = Word32.fromInt n + + val _ = (C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 0), + Word32.andb (w, mask1)); + C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 1), + Word32.andb (Word32.>> (w, eight), mask1)); + C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 2), + Word32.andb (Word32.>> (w, sixteen), mask1)); + C.Set.uchar' (C.Ptr.sub' C.S.uchar (readBuf, 3), + Word32.andb (Word32.>> (w, twentyfour), mask1))) + + fun trier (buf, count) = + let + val r = F_OpenSSL_SML_write.f' (bio, C.Ptr.inject' buf, count) + in + if r < 0 then + (ssl_err "BIO_write"; + raise OpenSSL "BIO_write") + else if r = count then () - else if r <= 0 then - (C.free' buf; - raise OpenSSL "BIO_write failed") else - loop (C.Ptr.|+! C.S.uchar (buf, Int32.toInt r), Int32.- (len, r)) + trier (C.Ptr.|+! C.S.uchar (buf, Int32.toInt r), count - r) end in - loop (buf, Int32.fromInt (size s)); - C.free' buf + trier (readBuf, 4) + end + +fun writeString' (bio, s) = + let + val buf = ZString.dupML' s + in + if F_OpenSSL_SML_puts.f' (bio, buf) <= 0 then + (C.free' buf; + ssl_err "BIO_puts"; + raise OpenSSL "BIO_puts") + else + C.free' buf end +fun writeString (bio, s) = + (writeInt (bio, size s); + writeString' (bio, s)) + fun context (chain, key, root) = let val context = F_OpenSSL_SML_CTX_new.f' (F_OpenSSL_SML_SSLv23_method.f' ()) @@ -169,7 +301,18 @@ fun accept listener = if F_OpenSSL_SML_do_accept.f' listener <= 0 then NONE else - SOME (F_OpenSSL_SML_pop.f' listener) + let + val bio = F_OpenSSL_SML_pop.f' listener + in + if C.Ptr.isNull' bio then + (ssl_err "Null accepted"; + raise OpenSSL "Null accepted") + else if F_OpenSSL_SML_do_handshake.f' bio <= 0 then + (ssl_err "Handshake failed"; + raise OpenSSL "Handshake failed") + else + SOME bio + end fun peerCN bio = let