fwtool: allow multiple nodes per rule
[hcoop/domtool2.git] / src / plugins / firewall.sml
index a693642..e6f92b2 100644 (file)
@@ -1,6 +1,6 @@
 (* HCoop Domtool (http://hcoop.sourceforge.net/)
  * Copyright (c) 2006-2007, Adam Chlipala
- * Copyright (c) 2011,2012,2013 Clinton Ebadi
+ * Copyright (c) 2011,2012,2013,2014 Clinton Ebadi
  *
  * This program is free software; you can redistribute it and/or
  * modify it under the terms of the GNU General Public License
@@ -43,31 +43,46 @@ fun parseRules () =
        fun parsePorts ports =
            List.mapPartial Int.fromString (String.fields (fn ch => ch = #",") ports)
            (* Just drop bad ports for now *)
-                      
+
+       fun parseNodes nodes = String.fields (fn ch => ch = #",") nodes
+
        fun loop parsedRules =
            case TextIO.inputLine inf of
                NONE => parsedRules
              | SOME line =>
                case String.tokens Char.isSpace line of
-                   node :: uname :: rest =>
-                   (case rest of
-                       "Client" :: ports :: hosts => loop ((User uname, FirewallNode node, Client (parsePorts ports, hosts)) :: parsedRules)
-                     | "Server" :: ports :: hosts => loop ((User uname, FirewallNode node, Server (parsePorts ports, hosts)) :: parsedRules)
-                     | ["ProxiedServer", ports] => loop ((User uname, FirewallNode node, ProxiedServer (parsePorts ports)) :: parsedRules)
-                     | ["LocalServer", ports] => loop ((User uname, FirewallNode node, LocalServer (parsePorts ports)) :: parsedRules)
-                     | _ => (print "Invalid config line\n"; loop parsedRules))
+                   nodes :: uname :: rest =>
+                   let
+                       val nodes = parseNodes nodes
+                   in
+                       case rest of
+                           "Client" :: ports :: hosts => loop (map (fn node => (User uname, FirewallNode node, Client (parsePorts ports, hosts))) nodes) @ parsedRules
+                         | "Server" :: ports :: hosts => loop (map (fn node => (User uname, FirewallNode node, Server (parsePorts ports, hosts))) nodes) @ parsedRules
+                         | ["ProxiedServer", ports]   => loop (map (fn node => (User uname, FirewallNode node, ProxiedServer (parsePorts ports))) nodes) @ parsedRules
+                         | ["LocalServer", ports]     => loop (map (fn node => (User uname, FirewallNode node, LocalServer (parsePorts ports))) nodes)   @ parsedRules
+                         | _ => (print "Invalid config line\n"; loop parsedRules)
+                   end
                  | _ => loop parsedRules
     in
        loop []
     end
 
-fun query uname =
+fun formatQueryRule (Client (ports, hosts)) =
+    "Client " ^ String.concatWith "," (map Int.toString ports) ^ " " ^ String.concatWith " " hosts
+  | formatQueryRule  (Server (ports, hosts)) =
+    "Server " ^ String.concatWith "," (map Int.toString ports) ^ " " ^ String.concatWith " " hosts
+  | formatQueryRule (ProxiedServer ports) =
+    "ProxiedServer " ^ String.concatWith "," (map Int.toString ports)
+  | formatQueryRule (LocalServer ports) =
+    "LocalServer " ^ String.concatWith "," (map Int.toString ports)
+
+fun query (node, uname) =
     (* completely broken *)
     let
        val rules = parseRules ()
     in
-       (* map (fn (_, FirewallNode n, r) => (n, r)) (List.filter (fn (User u, _, _) => u = uname) rules) *)
-       ["broken"]
+       map (fn (_, _, r) => formatQueryRule r)
+           (List.filter (fn (User u, FirewallNode n, _) => u = uname andalso n = node) rules)
     end
 
 fun formatPorts ports = "(" ^ String.concatWith " " (map Int.toString ports) ^ ")"
@@ -76,10 +91,12 @@ fun formatHosts hosts = "(" ^ String.concatWith " " hosts ^ ")"
 fun formatOutputRule (Client (ports, hosts)) = "dport " ^ formatPorts ports ^ (case hosts of 
                                                                                 [] => ""
                                                                               | _ => " daddr " ^ formatHosts hosts) ^ " ACCEPT;"
+  | formatOutputRule _ = ""
 
 fun formatInputRule (Server (ports, hosts)) = "dport " ^ formatPorts ports ^ (case hosts of 
                                                                                  [] => ""
                                                                                | _ => " saddr " ^ formatHosts hosts) ^ " ACCEPT;"
+  | formatInputRule _ = ""
 
 type ferm_lines = { input_rules : (string list) DataStructures.StringMap.map,
                    output_rules : (string list) DataStructures.StringMap.map } 
@@ -151,9 +168,13 @@ fun generateFirewallConfig rules =
            (* We can't match the user when listening; SELinux or
               similar would let us manage this with better
               granularity.*)
-           (TextIO.output (users_tcp_in_conf, "proto tcp {\n");
-            TextIO.output (users_tcp_in_conf, concat lines);
-            TextIO.output (users_tcp_in_conf, "\n}\n\n"))
+           let
+               val _ = SysWord.toInt (Posix.ProcEnv.uidToWord (Posix.SysDB.Passwd.uid (Posix.SysDB.getpwnam uname)))
+           in
+               TextIO.output (users_tcp_in_conf, "proto tcp {\n");
+               TextIO.output (users_tcp_in_conf, concat lines);
+               TextIO.output (users_tcp_in_conf, "\n}\n\n")
+           end handle OS.SysErr _ => print "Invalid user in firewall config, skipping.\n" (* no sense in opening ports for bad users *)                
 
        fun writeUserOutRules (uname, lines) =
            let