Overhaul fwtool
[hcoop/domtool2.git] / src / plugins / firewall.sml
index 378e127..a693642 100644 (file)
@@ -1,6 +1,6 @@
 (* HCoop Domtool (http://hcoop.sourceforge.net/)
  * Copyright (c) 2006-2007, Adam Chlipala
- * Copyright (c) 2011 Clinton Ebadi
+ * Copyright (c) 2011,2012,2013 Clinton Ebadi
  *
  * This program is free software; you can redistribute it and/or
  * modify it under the terms of the GNU General Public License
 
 structure Firewall :> FIREWALL = struct
 
-type firewall_rules = { server_rules : ((string * string) list DataStructures.StringMap.map), 
-                       client_rules : ((string * string) list DataStructures.StringMap.map)}
+datatype user = User of string
+                   
+datatype fwnode = FirewallNode of string
+
+datatype fwrule = Client of int list * string list
+               | Server of int list * string list
+               | ProxiedServer of int list
+               | LocalServer of int list
+
+type firewall_rules = (user * fwnode * fwrule) list
 
 structure StringMap = DataStructures.StringMap
 
 fun parseRules () =
     let
        val inf = TextIO.openIn Config.Firewall.firewallRules
-       val out_lines = ref StringMap.empty
-       val in_lines = ref StringMap.empty
-
-       fun confLine r (node, uname, line) =
-           let
-               val line = (node, String.concat ["\t", line, "\n"])
-               val lines = case StringMap.find (!r, uname) of
-                               NONE => []
-                             | SOME lines => lines
-           in
-               r := StringMap.insert (!r, uname, line :: lines)
-           end
-
-       val confLine_in = confLine in_lines
-       val confLine_out = confLine out_lines
 
        fun parsePorts ports =
-           case String.fields (fn ch => ch = #",") ports of
-               [pp] => pp
-             | pps => String.concat ["(", String.concatWith " " pps, ")"]
+           List.mapPartial Int.fromString (String.fields (fn ch => ch = #",") ports)
+           (* Just drop bad ports for now *)
                       
-       fun parseHosts addr hosts =
-           case hosts of
-               [] => ""
-             | [host] => String.concat [" ", addr, " ", host]
-             | _ => String.concat [" ", addr, " (", String.concatWith " " hosts, ")"]
-
-       fun loop () =
+       fun loop parsedRules =
            case TextIO.inputLine inf of
-               NONE => ()
+               NONE => parsedRules
              | SOME line =>
                case String.tokens Char.isSpace line of
                    node :: uname :: rest =>
                    (case rest of
-                       "Client" :: ports :: hosts =>
-                       confLine_out (node, uname, String.concat ["dport ", parsePorts ports, parseHosts "daddr" hosts, " ACCEPT;"])
-                     | "Server" :: ports :: hosts =>
-                       confLine_in (node, uname, String.concat ["dport ", parsePorts ports, parseHosts "daddr" hosts, " ACCEPT;"])
-                     | ["LocalServer", ports] =>
-                       confLine_in (node, uname, String.concat ["saddr $WE dport ", parsePorts ports, " ACCEPT;"])
-                     | _ => print "Invalid config line\n";
-                    loop ())
-                 | _ => loop ()
-       val _ = loop ()
+                       "Client" :: ports :: hosts => loop ((User uname, FirewallNode node, Client (parsePorts ports, hosts)) :: parsedRules)
+                     | "Server" :: ports :: hosts => loop ((User uname, FirewallNode node, Server (parsePorts ports, hosts)) :: parsedRules)
+                     | ["ProxiedServer", ports] => loop ((User uname, FirewallNode node, ProxiedServer (parsePorts ports)) :: parsedRules)
+                     | ["LocalServer", ports] => loop ((User uname, FirewallNode node, LocalServer (parsePorts ports)) :: parsedRules)
+                     | _ => (print "Invalid config line\n"; loop parsedRules))
+                 | _ => loop parsedRules
     in
-       {server_rules = !in_lines, client_rules = !out_lines}
+       loop []
     end
 
 fun query uname =
+    (* completely broken *)
     let
        val rules = parseRules ()
     in
-       List.map (fn (n,r) => r ^ " #host: " ^ n) (getOpt (StringMap.find (#server_rules rules, uname), []) @ getOpt (StringMap.find (#client_rules rules, uname), []))
+       (* map (fn (_, FirewallNode n, r) => (n, r)) (List.filter (fn (User u, _, _) => u = uname) rules) *)
+       ["broken"]
     end
 
+fun formatPorts ports = "(" ^ String.concatWith " " (map Int.toString ports) ^ ")"
+fun formatHosts hosts = "(" ^ String.concatWith " " hosts ^ ")"
+
+fun formatOutputRule (Client (ports, hosts)) = "dport " ^ formatPorts ports ^ (case hosts of 
+                                                                                [] => ""
+                                                                              | _ => " daddr " ^ formatHosts hosts) ^ " ACCEPT;"
+
+fun formatInputRule (Server (ports, hosts)) = "dport " ^ formatPorts ports ^ (case hosts of 
+                                                                                 [] => ""
+                                                                               | _ => " saddr " ^ formatHosts hosts) ^ " ACCEPT;"
+
+type ferm_lines = { input_rules : (string list) DataStructures.StringMap.map,
+                   output_rules : (string list) DataStructures.StringMap.map } 
+
+fun generateNodeFermRules rules  = 
+    let
+       fun filter_node_rules rules =
+           List.filter (fn (uname, FirewallNode node, rule) => node = Slave.hostname () orelse case rule of 
+                                                                                                   ProxiedServer _ => List.exists (fn (h,_) => h = Slave.hostname ()) Config.Apache.webNodes_all
+                                                                                    | _ => false)
+                       rules
 
-fun generateFirewallConfig {server_rules, client_rules} =
+       val inputLines = ref StringMap.empty
+       val outputLines = ref StringMap.empty
+
+       fun confLine r (User uname, line) =
+           let
+               val line = "\t" ^ line ^ "\n"
+               val lines = case StringMap.find (!r, uname) of
+                               NONE => []
+                             | SOME lines => lines
+           in
+               r := StringMap.insert (!r, uname, line :: lines)
+           end
+
+       fun confLine_in (uname, rule) = confLine inputLines (uname, formatInputRule rule)
+       fun confLine_out (uname, rule) = confLine outputLines (uname, formatOutputRule rule)
+
+       fun insertConfLine (uname, ruleNode, rule) =
+           case rule of 
+               Client (ports, hosts) => confLine_out (uname, rule)
+             | Server (ports, hosts) => confLine_in (uname, rule)
+             | LocalServer ports => (insertConfLine (uname, ruleNode, Client (ports, ["127.0.0.1/8"]));
+                                     insertConfLine (uname, ruleNode, Server (ports, ["127.0.0.1/8"])))
+             | ProxiedServer ports => if (fn FirewallNode r => r) ruleNode = Slave.hostname () then
+                                          (insertConfLine (uname, ruleNode, Server (ports, ["$WEBNODES"]));
+                                           insertConfLine (uname, ruleNode, Client (ports, [(fn FirewallNode r => r) ruleNode])))
+                                      else (* we are a web server *)
+                                          (insertConfLine (uname, ruleNode, Client (ports, [(fn FirewallNode r => r) ruleNode]));
+                                           insertConfLine (User "www-data", ruleNode, Client (ports, [(fn FirewallNode r => r) ruleNode])))
+
+       val _ = map insertConfLine (filter_node_rules rules)
+    in
+       { input_rules = !inputLines,
+         output_rules = !outputLines }
+
+
+    end
+
+fun generateFirewallConfig rules =
     (* rule generation must happen on the node (mandating the even
        service users be pts users would make it possible to do on the
        server, but that's not happening any time soon) *)
     let
        val users_tcp_out_conf = TextIO.openOut (Config.Firewall.firewallDir ^ "/users_tcp_out.conf")
        val users_tcp_in_conf = TextIO.openOut (Config.Firewall.firewallDir ^ "/users_tcp_in.conf")
-       val users_conf = TextIO.openOut (Config.Firewall.firewallDir ^ "/user_chains.conf")
-
-       fun filter_node_rules lines =
-           (* filter out rules for other hosts here... really not
-           ideal, but it should work for the time being *)
-           List.map (fn (node, line) => line)
-                    (List.filter (fn (node, line) => node = Slave.hostname ()) lines)
-
-       fun write_user_tcp_conf (rules, outf, suffix) =
-           StringMap.appi (fn (uname, rules) =>
-                              let
-                                  val uid = SysWord.toInt (Posix.ProcEnv.uidToWord (Posix.SysDB.Passwd.uid (Posix.SysDB.getpwnam uname)))
-                                  val lines = filter_node_rules rules
-                              in
-                                  TextIO.output (outf, String.concat
-                                                           ["mod owner uid-owner ",
-                                                            Int.toString uid,
-                                                            " { goto user_",
-                                                            uname,
-                                                            suffix,
-                                                            "; DROP; }\n"]);
-                                  (* Is there any point to splitting the rules like this? *)
-                                  TextIO.output (users_conf,
-                                                 String.concat ("chain user_"
-                                                                :: uname
-                                                                :: suffix
-                                                                :: " proto tcp {\n"
-                                                                :: lines
-                                                                @ ["}\n\n"]))
-                              end handle OS.SysErr _ => print "Invalid user in firewall config, skipping.\n")
-                          rules
+       val user_chains_conf = TextIO.openOut (Config.Firewall.firewallDir ^ "/user_chains.conf")
+
+       val nodeFermRules = generateNodeFermRules rules
+               
+       fun write_tcp_in_conf_preamble outf = 
+           TextIO.output (outf, String.concat ["@def $WEBNODES = (",
+                                               (String.concatWith " " (List.map (fn (_, ip) => ip) 
+                                                                                (List.filter (fn (node, _) => List.exists (fn (n) => n = node) (List.map (fn (node, _) => node) (Config.Apache.webNodes_all @ Config.Apache.webNodes_admin)))
+                                                                                             Config.nodeIps))),
+                                               ");\n\n"])
+
+       fun writeUserInRules (uname, lines) = 
+           (* We can't match the user when listening; SELinux or
+              similar would let us manage this with better
+              granularity.*)
+           (TextIO.output (users_tcp_in_conf, "proto tcp {\n");
+            TextIO.output (users_tcp_in_conf, concat lines);
+            TextIO.output (users_tcp_in_conf, "\n}\n\n"))
+
+       fun writeUserOutRules (uname, lines) =
+           let
+               val uid = SysWord.toInt (Posix.ProcEnv.uidToWord (Posix.SysDB.Passwd.uid (Posix.SysDB.getpwnam uname)))
+           in
+               TextIO.output (users_tcp_out_conf, "mod owner uid-owner " ^ (Int.toString uid)
+                                                  ^ " { jump user_" ^ uname ^ "_tcp_out"
+                                                  ^ "; DROP; }\n");
+
+               TextIO.output (user_chains_conf, "chain user_" ^ uname ^ "_tcp_out"
+                                                ^ " proto tcp {\n");
+               TextIO.output (user_chains_conf, concat lines);
+               TextIO.output (user_chains_conf, "\n}\n\n")
+           end handle OS.SysErr _ => print "Invalid user in firewall config, skipping.\n"
+           
     in
-       write_user_tcp_conf (server_rules, users_tcp_in_conf, "_tcp_in");
-       write_user_tcp_conf (client_rules, users_tcp_out_conf, "_tcp_out");
+       write_tcp_in_conf_preamble (users_tcp_in_conf);
+       StringMap.appi (writeUserOutRules) (#output_rules nodeFermRules);
+       StringMap.appi (writeUserInRules) (#input_rules nodeFermRules);
 
-       TextIO.closeOut users_conf;
+       TextIO.closeOut user_chains_conf;
        TextIO.closeOut users_tcp_out_conf;
        TextIO.closeOut users_tcp_in_conf;