From efe5901a55de0e610754025b7b813968e29a2d59 Mon Sep 17 00:00:00 2001 From: strNophix Date: Thu, 7 Apr 2022 14:02:04 +0200 Subject: [PATCH] Refactored service.py and added recovery cmd --- service.py | 153 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 92 insertions(+), 61 deletions(-) diff --git a/service.py b/service.py index 3d92f23..a63dc71 100755 --- a/service.py +++ b/service.py @@ -7,7 +7,7 @@ from os import path from pathlib import Path import os import time -from typing import Any, Callable, DefaultDict, Iterable, Mapping +from typing import Any, DefaultDict, Iterable, Mapping from jinja2 import Template import itertools import subprocess as sub @@ -19,7 +19,40 @@ from dataclasses import dataclass TEMPLATE_DIR = "./templates" +#region util +def gen_pass() -> str: + alphabet = string.ascii_letters + string.digits + password = "".join(secrets.choice(alphabet) for _ in range(20)) + return password + +def check_positive(value: str): + ivalue = int(value) + if ivalue <= 0: + raise Exception("Supplied number must be >= 0") + return ivalue + + +def encode_member(member: str, mapping: Mapping[str, Any]) -> str: + return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()]) + + +def iter_ips(ip_format: str, start_octet: int): + ip_int = start_octet + ip_fmt = ip_format + while ip_int < 255: + yield ip_fmt.format(ip_int) + ip_int += 1 + +def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}): + c = Path(src).read_text() + t: Template = Template(c) + r = t.render(mapping) + Path(dest).write_text(r) + +#endregion + +#region models @dataclass class MachineResources: cpus: int @@ -45,32 +78,6 @@ class MachineResources: return MachineResources(cpus=int(cpus), mem=int(mem)) - -def gen_pass() -> str: - alphabet = string.ascii_letters + string.digits - password = "".join(secrets.choice(alphabet) for _ in range(20)) - return password - - -def check_positive(value: str): - ivalue = int(value) - if ivalue <= 0: - raise Exception("Supplied number must be >= 0") - return ivalue - - -def encode_member(member: str, mapping: Mapping[str, Any]) -> str: - return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()]) - - -def iter_ips(ip_format: str, start_octet: int): - ip_int = start_octet - ip_fmt = ip_format - while ip_int < 255: - yield ip_fmt.format(ip_int) - ip_int += 1 - - class InventoryWriter: def __init__(self, location: str) -> None: self._file_handle = Path(location) @@ -88,15 +95,9 @@ class InventoryWriter: for name, members in self._groups.items(): txt += self._build_group(name, members) + "\n\n" self._file_handle.write_text(txt, encoding="utf8") +#endregion - -def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}): - c = Path(src).read_text() - t: Template = Template(c) - r = t.render(mapping) - Path(dest).write_text(r) - - +#region CLI positional flows def list_envs(args: argparse.Namespace): try: customer_path = path.join("customers", args.customer_name, "envs") @@ -112,6 +113,17 @@ def delete_env(args: argparse.Namespace): shutil.rmtree(env_path) print(f"Deleted `{env}` from customer `{ args.customer_name}`") +def recover_env(args: argparse.Namespace): + for env in args.env_names: + env_path = path.join("customers", args.customer_name, "envs", env) + sub.call(["vagrant", "up"], cwd=env_path) + + # Artificial sleep waiting for sshd on VM to correctly start + print("Waiting on virtual machines...") + time.sleep(5) + + sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path) + print(f"Recovered `{env}` from customer `{ args.customer_name}`") def create_env(args: argparse.Namespace): if (args.num_nginx_web + args.num_nginx_lb + args.num_postgres) == 0: @@ -119,21 +131,34 @@ def create_env(args: argparse.Namespace): env_path = path.join("customers", args.customer_name, "envs", args.env_name) Path(env_path).mkdir(exist_ok=True, parents=True) + vagrant_mapping: dict[str, Any] = { + "webserver_specs": None, + "loadbalancers_specs": None, + "postgres_specs": None, + "num_webserver": args.num_nginx_web, + "num_loadbalancers": args.num_nginx_lb, + "num_postgres": args.num_postgres, + "env": args.env_name, + "customer_name": args.customer_name, + "ip_int": args.ip_int, + "ip_format": args.ip_format.replace("{}", "%d"), + } - web_specs = None if args.num_nginx_web > 0: print("\nNginx webserver resources:") web_specs = MachineResources.from_prompt() - - lb_specs = None + vagrant_mapping["webserver_specs"] = dataclasses.asdict(web_specs) + if args.num_nginx_lb > 0: print("\nNginx loadbalancer resources: ") lb_specs = MachineResources.from_prompt() + vagrant_mapping["loadbalancers_specs"] = dataclasses.asdict(lb_specs) + - psql_specs = None if args.num_postgres > 0: print("\nPostgresql machine resources: ") psql_specs = MachineResources.from_prompt() + vagrant_mapping["postgres_specs"] = dataclasses.asdict(psql_specs) # Template `ansible.cfg` src = path.join(TEMPLATE_DIR, "ansible.cfg.template") @@ -151,9 +176,8 @@ def create_env(args: argparse.Namespace): lb_ips = itertools.islice(ip_generator, args.num_nginx_lb) iw.add("loadbalancer", lb_ips) - psql_gen_pass: Callable[[str], str] = lambda x: encode_member( - x, {"psql_pass": gen_pass()} - ) + def psql_gen_pass(x: str) -> str: + return encode_member(x, {"psql_pass": gen_pass()}) psql_ips = list(itertools.islice(ip_generator, args.num_postgres)) psql_ips = map(psql_gen_pass, psql_ips) @@ -164,28 +188,16 @@ def create_env(args: argparse.Namespace): # Template `Vagrantfile` src = path.join(TEMPLATE_DIR, "Vagrantfile.template") dest = path.join(env_path, "Vagrantfile") + should_reload = Path(dest).exists() + copy_template(src=src, dest=dest, mapping=vagrant_mapping) - mapping = { - "env": args.env_name, - "customer_name": args.customer_name, - "ip_int": args.ip_int, - "ip_format": args.ip_format.replace("{}", "%d"), - "num_webserver": args.num_nginx_web, - "num_loadbalancers": args.num_nginx_lb, - "num_postgres": args.num_postgres, - "webserver_specs": dataclasses.asdict(web_specs), - "loadbalancers_specs": dataclasses.asdict(lb_specs), - "postgres_specs": dataclasses.asdict(psql_specs), - } - copy_template(src=src, dest=dest, mapping=mapping) - - # Generate .ssh + # Generate .ssh if it doesn't exist already ssh_dir = path.join(env_path, ".ssh") Path(ssh_dir).mkdir(exist_ok=True) rsa_path = path.join(ssh_dir, "id_rsa") if not Path(rsa_path).exists(): print(end="\n") - ssh_key_cmd = [ + sub.call([ "ssh-keygen", "-t", "rsa", @@ -193,24 +205,34 @@ def create_env(args: argparse.Namespace): "2048", "-f", rsa_path, - ] - sub.call(ssh_key_cmd) + ]) # Provision and configure machines + # Create VM's that do not exist yet sub.call(["vagrant", "up"], cwd=env_path) + + # Update VM's that already existed + if should_reload: + sub.call(["vagrant", "reload", "--provision"], cwd=env_path) + + # Artificial sleep waiting for sshd on VM to correctly start print("Waiting on virtual machines...") time.sleep(5) + sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path) +#endregion def main() -> int: parser = argparse.ArgumentParser() sub_parser = parser.add_subparsers() + # CLI definition for positional arg "list" list_parser = sub_parser.add_parser("list", help="list customer-owned environments") list_parser.add_argument("customer_name", type=str, help="name of the customer") list_parser.set_defaults(func=list_envs) + # CLI definition for positional arg "create" cenv_parser = sub_parser.add_parser("create", help="create a new environment") cenv_parser.add_argument("customer_name", type=str, help="name of the customer") cenv_parser.add_argument("env_name", type=str, help="name of the environment") @@ -240,14 +262,23 @@ def main() -> int: ) cenv_parser.set_defaults(func=create_env) + # CLI definition for positional arg "delete" denv_parser = sub_parser.add_parser("delete", help="delete an environment") denv_parser.add_argument("customer_name", type=str, help="name of the customer") denv_parser.add_argument( "env_names", type=str, nargs="+", help="name of one or more environments" ) - denv_parser.set_defaults(func=delete_env) + # CLI definition for positional arg "recover" + denv_parser = sub_parser.add_parser("recover", help="attempts to recover an env") + denv_parser.add_argument("customer_name", type=str, help="name of the customer") + denv_parser.add_argument( + "env_names", type=str, nargs="+", help="name of one or more environments" + ) + denv_parser.set_defaults(func=recover_env) + + # Parse args args = parser.parse_args(sys.argv[1:]) args.func(args) return 0