#!/usr/bin/env python3 from __future__ import annotations import argparse import dataclasses import sys from os import path from pathlib import Path import os import time from typing import Any, DefaultDict, Iterable, Mapping from jinja2 import Template import itertools import subprocess as sub import shutil import string import secrets from dataclasses import dataclass TEMPLATE_DIR = "./templates" #region util def gen_pass(pass_length: int = 20) -> str: """Generates a simple password with the provided length Args: pass_length (int, optional): Length of password. Defaults to 20. Returns: str: The generated password. """ alphabet = string.ascii_letters + string.digits password = "".join(secrets.choice(alphabet) for _ in range(pass_length)) return password def check_positive(value: str) -> int: """`argparse` type helper to check whether a given input is a positive integer. Args: value (str): Input to validate Raises: Exception: Input is neither a decimal or positive Returns: int: The parsed input if valid """ ivalue = int(value) if ivalue < 0: raise Exception(f"Supplied number must be >= 0") return ivalue def encode_member(member: str, mapping: Mapping[str, Any]) -> str: """Encodes the member-entry of an inventory file with additional mappings. Args: member (str) mapping (Mapping[str, Any]) Returns: str: member with mapping encoded """ return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()]) def iter_ips(ip_format: str, start_octet: int): """Simple iterator for generating ip's Args: ip_format (str) start_octet (int) Yields: Yeah idk too lazy too look up what the type annotation for a generator is """ ip_int = start_octet ip_fmt = ip_format while ip_int < 255: yield ip_fmt.format(ip_int) ip_int += 1 def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}): """Templates and writes a template file using Jinja as templating engine. Args: src (str): content of the file dest (str): place to write templated file to mapping (Mapping[str, Any], optional): datamapping. Defaults to {}. """ c = Path(src).read_text() t: Template = Template(c) r = t.render(mapping) Path(dest).write_text(r) #endregion #region models @dataclass class MachineResources: cpus: int mem: int @staticmethod def from_prompt() -> "MachineResources": """Generate `MachineResources` from prompt. Raises: Exception Exception Returns: MachineResources """ cpus = input( "How many processors would you like to assign (default=1): ") if not cpus: cpus = "1" if not cpus.isdigit() or int(cpus) < 0: raise Exception("Expected a postive amount of processors") mem = input( "How many megabytes of RAM would you like to assign (default=1024): " ) if not mem: mem = "1024" if not mem.isdigit() or int(mem) < 0: raise Exception("Expected a postive amount of memory") return MachineResources(cpus=int(cpus), mem=int(mem)) class InventoryWriter: """ Helper class for generating Ansible inventory files. """ def __init__(self, location: str) -> None: self._file_handle = Path(location) self._groups: dict[str, set[str]] = DefaultDict(set) def add(self, name: str, members: Iterable[str]): self._groups[name] |= set(members) def _build_group(self, name: str, members: set[str]): fmt = f"[{name}]\n" + "\n".join(members) return fmt def flush(self): txt = "" for name, members in self._groups.items(): txt += self._build_group(name, members) + "\n\n" self._file_handle.write_text(txt, encoding="utf8") #endregion #region CLI positional flows def list_envs(args: argparse.Namespace): try: customer_path = path.join("customers", args.customer_name, "envs") print(" ".join(os.listdir(customer_path))) except FileNotFoundError: print(f"Customer `{args.customer_name}` does not exist.") def delete_env(args: argparse.Namespace): for env in args.env_names: env_path = path.join("customers", args.customer_name, "envs", env) sub.call(["vagrant", "destroy", "-f"], cwd=env_path) shutil.rmtree(env_path) print(f"Deleted `{env}` from customer `{ args.customer_name}`") def recover_env(args: argparse.Namespace): for env in args.env_names: env_path = path.join("customers", args.customer_name, "envs", env) sub.call(["vagrant", "up"], cwd=env_path) # Artificial sleep waiting for sshd on VM to correctly start print("Waiting on virtual machines...") time.sleep(5) sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path) print(f"Recovered `{env}` from customer `{ args.customer_name}`") def create_env(args: argparse.Namespace): if (args.num_nginx_web + args.num_nginx_lb + args.num_postgres) == 0: raise Exception("At least one item should be deployed") env_path = path.join("customers", args.customer_name, "envs", args.env_name) Path(env_path).mkdir(exist_ok=True, parents=True) vagrant_mapping: dict[str, Any] = { "webserver_specs": None, "loadbalancers_specs": None, "postgres_specs": None, "num_webserver": args.num_nginx_web, "num_loadbalancers": args.num_nginx_lb, "num_postgres": args.num_postgres, "env": args.env_name, "customer_name": args.customer_name, "ip_int": args.ip_int, "ip_format": args.ip_format.replace("{}", "%d"), } if args.num_nginx_web > 0: print("\nNginx webserver resources:") web_specs = MachineResources.from_prompt() vagrant_mapping["webserver_specs"] = dataclasses.asdict(web_specs) if args.num_nginx_lb > 0: print("\nNginx loadbalancer resources: ") lb_specs = MachineResources.from_prompt() vagrant_mapping["loadbalancers_specs"] = dataclasses.asdict(lb_specs) if args.num_postgres > 0: print("\nPostgresql machine resources: ") psql_specs = MachineResources.from_prompt() vagrant_mapping["postgres_specs"] = dataclasses.asdict(psql_specs) # Template `ansible.cfg` src = path.join(TEMPLATE_DIR, "ansible.cfg.template") dest = path.join(env_path, "ansible.cfg") copy_template(src=src, dest=dest) # Create inventory file inv_path = path.join(env_path, "inventory") iw = InventoryWriter(inv_path) ip_generator = iter_ips(args.ip_format, args.ip_int) web_ips = itertools.islice(ip_generator, args.num_nginx_web) iw.add("webserver", web_ips) lb_ips = itertools.islice(ip_generator, args.num_nginx_lb) iw.add("loadbalancer", lb_ips) def psql_gen_pass(x: str) -> str: return encode_member(x, {"psql_pass": gen_pass()}) psql_ips = list(itertools.islice(ip_generator, args.num_postgres)) psql_ips = map(psql_gen_pass, psql_ips) iw.add("postgresql", psql_ips) iw.flush() # Template `Vagrantfile` src = path.join(TEMPLATE_DIR, "Vagrantfile.template") dest = path.join(env_path, "Vagrantfile") should_reload = Path(dest).exists() copy_template(src=src, dest=dest, mapping=vagrant_mapping) # Generate .ssh if it doesn't exist already ssh_dir = path.join(env_path, ".ssh") Path(ssh_dir).mkdir(exist_ok=True) rsa_path = path.join(ssh_dir, "id_rsa") if not Path(rsa_path).exists(): print(end="\n") sub.call([ "ssh-keygen", "-t", "rsa", "-b", "2048", "-f", rsa_path, ]) # Provision and configure machines # Create VM's that do not exist yet sub.call(["vagrant", "up"], cwd=env_path) # Update VM's that already existed if should_reload: sub.call(["vagrant", "reload", "--provision"], cwd=env_path) # Artificial sleep waiting for sshd on VM to correctly start print("Waiting on virtual machines...") time.sleep(5) sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path) #endregion def main() -> int: parser = argparse.ArgumentParser() sub_parser = parser.add_subparsers() # CLI definition for positional arg "list" list_parser = sub_parser.add_parser( "list", help="list customer-owned environments") list_parser.add_argument("customer_name", type=str, help="name of the customer") list_parser.set_defaults(func=list_envs) # CLI definition for positional arg "create" cenv_parser = sub_parser.add_parser("create", help="create a new environment") cenv_parser.add_argument("customer_name", type=str, help="name of the customer") cenv_parser.add_argument("env_name", type=str, help="name of the environment") cenv_parser.add_argument( "--num-postgres", type=check_positive, help="number of postgres databases", default=0, ) cenv_parser.add_argument( "--num-nginx-web", type=check_positive, help="number of nginx webservers", default=0, ) cenv_parser.add_argument( "--num-nginx-lb", type=check_positive, help="number of nginx loadbalancers", default=0, ) cenv_parser.add_argument("--ip-format", type=str, help="format of ip", default="192.168.56.{}") cenv_parser.add_argument("--ip-int", type=check_positive, help="4th octet to start at", default="10") cenv_parser.set_defaults(func=create_env) # CLI definition for positional arg "delete" denv_parser = sub_parser.add_parser("delete", help="delete an environment") denv_parser.add_argument("customer_name", type=str, help="name of the customer") denv_parser.add_argument("env_names", type=str, nargs="+", help="name of one or more environments") denv_parser.set_defaults(func=delete_env) # CLI definition for positional arg "recover" denv_parser = sub_parser.add_parser("recover", help="attempts to recover an env") denv_parser.add_argument("customer_name", type=str, help="name of the customer") denv_parser.add_argument("env_names", type=str, nargs="+", help="name of one or more environments") denv_parser.set_defaults(func=recover_env) # Parse args args = parser.parse_args(sys.argv[1:]) args.func(args) return 0 if __name__ == "__main__": raise SystemExit(main())