Refactored service.py and added recovery cmd
This commit is contained in:
parent
7ed3be574a
commit
efe5901a55
153
service.py
153
service.py
@ -7,7 +7,7 @@ from os import path
|
||||
from pathlib import Path
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Mapping
|
||||
from typing import Any, DefaultDict, Iterable, Mapping
|
||||
from jinja2 import Template
|
||||
import itertools
|
||||
import subprocess as sub
|
||||
@ -19,7 +19,40 @@ from dataclasses import dataclass
|
||||
|
||||
TEMPLATE_DIR = "./templates"
|
||||
|
||||
#region util
|
||||
def gen_pass() -> str:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
password = "".join(secrets.choice(alphabet) for _ in range(20))
|
||||
return password
|
||||
|
||||
|
||||
def check_positive(value: str):
|
||||
ivalue = int(value)
|
||||
if ivalue <= 0:
|
||||
raise Exception("Supplied number must be >= 0")
|
||||
return ivalue
|
||||
|
||||
|
||||
def encode_member(member: str, mapping: Mapping[str, Any]) -> str:
|
||||
return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()])
|
||||
|
||||
|
||||
def iter_ips(ip_format: str, start_octet: int):
|
||||
ip_int = start_octet
|
||||
ip_fmt = ip_format
|
||||
while ip_int < 255:
|
||||
yield ip_fmt.format(ip_int)
|
||||
ip_int += 1
|
||||
|
||||
def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}):
|
||||
c = Path(src).read_text()
|
||||
t: Template = Template(c)
|
||||
r = t.render(mapping)
|
||||
Path(dest).write_text(r)
|
||||
|
||||
#endregion
|
||||
|
||||
#region models
|
||||
@dataclass
|
||||
class MachineResources:
|
||||
cpus: int
|
||||
@ -45,32 +78,6 @@ class MachineResources:
|
||||
|
||||
return MachineResources(cpus=int(cpus), mem=int(mem))
|
||||
|
||||
|
||||
def gen_pass() -> str:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
password = "".join(secrets.choice(alphabet) for _ in range(20))
|
||||
return password
|
||||
|
||||
|
||||
def check_positive(value: str):
|
||||
ivalue = int(value)
|
||||
if ivalue <= 0:
|
||||
raise Exception("Supplied number must be >= 0")
|
||||
return ivalue
|
||||
|
||||
|
||||
def encode_member(member: str, mapping: Mapping[str, Any]) -> str:
|
||||
return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()])
|
||||
|
||||
|
||||
def iter_ips(ip_format: str, start_octet: int):
|
||||
ip_int = start_octet
|
||||
ip_fmt = ip_format
|
||||
while ip_int < 255:
|
||||
yield ip_fmt.format(ip_int)
|
||||
ip_int += 1
|
||||
|
||||
|
||||
class InventoryWriter:
|
||||
def __init__(self, location: str) -> None:
|
||||
self._file_handle = Path(location)
|
||||
@ -88,15 +95,9 @@ class InventoryWriter:
|
||||
for name, members in self._groups.items():
|
||||
txt += self._build_group(name, members) + "\n\n"
|
||||
self._file_handle.write_text(txt, encoding="utf8")
|
||||
#endregion
|
||||
|
||||
|
||||
def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}):
|
||||
c = Path(src).read_text()
|
||||
t: Template = Template(c)
|
||||
r = t.render(mapping)
|
||||
Path(dest).write_text(r)
|
||||
|
||||
|
||||
#region CLI positional flows
|
||||
def list_envs(args: argparse.Namespace):
|
||||
try:
|
||||
customer_path = path.join("customers", args.customer_name, "envs")
|
||||
@ -112,6 +113,17 @@ def delete_env(args: argparse.Namespace):
|
||||
shutil.rmtree(env_path)
|
||||
print(f"Deleted `{env}` from customer `{ args.customer_name}`")
|
||||
|
||||
def recover_env(args: argparse.Namespace):
|
||||
for env in args.env_names:
|
||||
env_path = path.join("customers", args.customer_name, "envs", env)
|
||||
sub.call(["vagrant", "up"], cwd=env_path)
|
||||
|
||||
# Artificial sleep waiting for sshd on VM to correctly start
|
||||
print("Waiting on virtual machines...")
|
||||
time.sleep(5)
|
||||
|
||||
sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path)
|
||||
print(f"Recovered `{env}` from customer `{ args.customer_name}`")
|
||||
|
||||
def create_env(args: argparse.Namespace):
|
||||
if (args.num_nginx_web + args.num_nginx_lb + args.num_postgres) == 0:
|
||||
@ -119,21 +131,34 @@ def create_env(args: argparse.Namespace):
|
||||
|
||||
env_path = path.join("customers", args.customer_name, "envs", args.env_name)
|
||||
Path(env_path).mkdir(exist_ok=True, parents=True)
|
||||
vagrant_mapping: dict[str, Any] = {
|
||||
"webserver_specs": None,
|
||||
"loadbalancers_specs": None,
|
||||
"postgres_specs": None,
|
||||
"num_webserver": args.num_nginx_web,
|
||||
"num_loadbalancers": args.num_nginx_lb,
|
||||
"num_postgres": args.num_postgres,
|
||||
"env": args.env_name,
|
||||
"customer_name": args.customer_name,
|
||||
"ip_int": args.ip_int,
|
||||
"ip_format": args.ip_format.replace("{}", "%d"),
|
||||
}
|
||||
|
||||
web_specs = None
|
||||
if args.num_nginx_web > 0:
|
||||
print("\nNginx webserver resources:")
|
||||
web_specs = MachineResources.from_prompt()
|
||||
|
||||
lb_specs = None
|
||||
vagrant_mapping["webserver_specs"] = dataclasses.asdict(web_specs)
|
||||
|
||||
if args.num_nginx_lb > 0:
|
||||
print("\nNginx loadbalancer resources: ")
|
||||
lb_specs = MachineResources.from_prompt()
|
||||
vagrant_mapping["loadbalancers_specs"] = dataclasses.asdict(lb_specs)
|
||||
|
||||
|
||||
psql_specs = None
|
||||
if args.num_postgres > 0:
|
||||
print("\nPostgresql machine resources: ")
|
||||
psql_specs = MachineResources.from_prompt()
|
||||
vagrant_mapping["postgres_specs"] = dataclasses.asdict(psql_specs)
|
||||
|
||||
# Template `ansible.cfg`
|
||||
src = path.join(TEMPLATE_DIR, "ansible.cfg.template")
|
||||
@ -151,9 +176,8 @@ def create_env(args: argparse.Namespace):
|
||||
lb_ips = itertools.islice(ip_generator, args.num_nginx_lb)
|
||||
iw.add("loadbalancer", lb_ips)
|
||||
|
||||
psql_gen_pass: Callable[[str], str] = lambda x: encode_member(
|
||||
x, {"psql_pass": gen_pass()}
|
||||
)
|
||||
def psql_gen_pass(x: str) -> str:
|
||||
return encode_member(x, {"psql_pass": gen_pass()})
|
||||
|
||||
psql_ips = list(itertools.islice(ip_generator, args.num_postgres))
|
||||
psql_ips = map(psql_gen_pass, psql_ips)
|
||||
@ -164,28 +188,16 @@ def create_env(args: argparse.Namespace):
|
||||
# Template `Vagrantfile`
|
||||
src = path.join(TEMPLATE_DIR, "Vagrantfile.template")
|
||||
dest = path.join(env_path, "Vagrantfile")
|
||||
should_reload = Path(dest).exists()
|
||||
copy_template(src=src, dest=dest, mapping=vagrant_mapping)
|
||||
|
||||
mapping = {
|
||||
"env": args.env_name,
|
||||
"customer_name": args.customer_name,
|
||||
"ip_int": args.ip_int,
|
||||
"ip_format": args.ip_format.replace("{}", "%d"),
|
||||
"num_webserver": args.num_nginx_web,
|
||||
"num_loadbalancers": args.num_nginx_lb,
|
||||
"num_postgres": args.num_postgres,
|
||||
"webserver_specs": dataclasses.asdict(web_specs),
|
||||
"loadbalancers_specs": dataclasses.asdict(lb_specs),
|
||||
"postgres_specs": dataclasses.asdict(psql_specs),
|
||||
}
|
||||
copy_template(src=src, dest=dest, mapping=mapping)
|
||||
|
||||
# Generate .ssh
|
||||
# Generate .ssh if it doesn't exist already
|
||||
ssh_dir = path.join(env_path, ".ssh")
|
||||
Path(ssh_dir).mkdir(exist_ok=True)
|
||||
rsa_path = path.join(ssh_dir, "id_rsa")
|
||||
if not Path(rsa_path).exists():
|
||||
print(end="\n")
|
||||
ssh_key_cmd = [
|
||||
sub.call([
|
||||
"ssh-keygen",
|
||||
"-t",
|
||||
"rsa",
|
||||
@ -193,24 +205,34 @@ def create_env(args: argparse.Namespace):
|
||||
"2048",
|
||||
"-f",
|
||||
rsa_path,
|
||||
]
|
||||
sub.call(ssh_key_cmd)
|
||||
])
|
||||
|
||||
# Provision and configure machines
|
||||
# Create VM's that do not exist yet
|
||||
sub.call(["vagrant", "up"], cwd=env_path)
|
||||
|
||||
# Update VM's that already existed
|
||||
if should_reload:
|
||||
sub.call(["vagrant", "reload", "--provision"], cwd=env_path)
|
||||
|
||||
# Artificial sleep waiting for sshd on VM to correctly start
|
||||
print("Waiting on virtual machines...")
|
||||
time.sleep(5)
|
||||
|
||||
sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path)
|
||||
|
||||
#endregion
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
sub_parser = parser.add_subparsers()
|
||||
|
||||
# CLI definition for positional arg "list"
|
||||
list_parser = sub_parser.add_parser("list", help="list customer-owned environments")
|
||||
list_parser.add_argument("customer_name", type=str, help="name of the customer")
|
||||
list_parser.set_defaults(func=list_envs)
|
||||
|
||||
# CLI definition for positional arg "create"
|
||||
cenv_parser = sub_parser.add_parser("create", help="create a new environment")
|
||||
cenv_parser.add_argument("customer_name", type=str, help="name of the customer")
|
||||
cenv_parser.add_argument("env_name", type=str, help="name of the environment")
|
||||
@ -240,14 +262,23 @@ def main() -> int:
|
||||
)
|
||||
cenv_parser.set_defaults(func=create_env)
|
||||
|
||||
# CLI definition for positional arg "delete"
|
||||
denv_parser = sub_parser.add_parser("delete", help="delete an environment")
|
||||
denv_parser.add_argument("customer_name", type=str, help="name of the customer")
|
||||
denv_parser.add_argument(
|
||||
"env_names", type=str, nargs="+", help="name of one or more environments"
|
||||
)
|
||||
|
||||
denv_parser.set_defaults(func=delete_env)
|
||||
|
||||
# CLI definition for positional arg "recover"
|
||||
denv_parser = sub_parser.add_parser("recover", help="attempts to recover an env")
|
||||
denv_parser.add_argument("customer_name", type=str, help="name of the customer")
|
||||
denv_parser.add_argument(
|
||||
"env_names", type=str, nargs="+", help="name of one or more environments"
|
||||
)
|
||||
denv_parser.set_defaults(func=recover_env)
|
||||
|
||||
# Parse args
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
args.func(args)
|
||||
return 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user