Compare commits

...

2 Commits

2 changed files with 101 additions and 66 deletions

View File

@ -7,7 +7,7 @@ from os import path
from pathlib import Path from pathlib import Path
import os import os
import time import time
from typing import Any, Callable, DefaultDict, Iterable, Mapping from typing import Any, DefaultDict, Iterable, Mapping
from jinja2 import Template from jinja2 import Template
import itertools import itertools
import subprocess as sub import subprocess as sub
@ -19,7 +19,40 @@ from dataclasses import dataclass
TEMPLATE_DIR = "./templates" TEMPLATE_DIR = "./templates"
#region util
def gen_pass() -> str:
alphabet = string.ascii_letters + string.digits
password = "".join(secrets.choice(alphabet) for _ in range(20))
return password
def check_positive(value: str):
ivalue = int(value)
if ivalue <= 0:
raise Exception("Supplied number must be >= 0")
return ivalue
def encode_member(member: str, mapping: Mapping[str, Any]) -> str:
return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()])
def iter_ips(ip_format: str, start_octet: int):
ip_int = start_octet
ip_fmt = ip_format
while ip_int < 255:
yield ip_fmt.format(ip_int)
ip_int += 1
def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}):
c = Path(src).read_text()
t: Template = Template(c)
r = t.render(mapping)
Path(dest).write_text(r)
#endregion
#region models
@dataclass @dataclass
class MachineResources: class MachineResources:
cpus: int cpus: int
@ -45,32 +78,6 @@ class MachineResources:
return MachineResources(cpus=int(cpus), mem=int(mem)) return MachineResources(cpus=int(cpus), mem=int(mem))
def gen_pass() -> str:
alphabet = string.ascii_letters + string.digits
password = "".join(secrets.choice(alphabet) for _ in range(20))
return password
def check_positive(value: str):
ivalue = int(value)
if ivalue <= 0:
raise Exception("Supplied number must be >= 0")
return ivalue
def encode_member(member: str, mapping: Mapping[str, Any]) -> str:
return member + " " + " ".join([f"{k}={v}" for k, v in mapping.items()])
def iter_ips(ip_format: str, start_octet: int):
ip_int = start_octet
ip_fmt = ip_format
while ip_int < 255:
yield ip_fmt.format(ip_int)
ip_int += 1
class InventoryWriter: class InventoryWriter:
def __init__(self, location: str) -> None: def __init__(self, location: str) -> None:
self._file_handle = Path(location) self._file_handle = Path(location)
@ -88,15 +95,9 @@ class InventoryWriter:
for name, members in self._groups.items(): for name, members in self._groups.items():
txt += self._build_group(name, members) + "\n\n" txt += self._build_group(name, members) + "\n\n"
self._file_handle.write_text(txt, encoding="utf8") self._file_handle.write_text(txt, encoding="utf8")
#endregion
#region CLI positional flows
def copy_template(src: str, dest: str, mapping: Mapping[str, Any] = {}):
c = Path(src).read_text()
t: Template = Template(c)
r = t.render(mapping)
Path(dest).write_text(r)
def list_envs(args: argparse.Namespace): def list_envs(args: argparse.Namespace):
try: try:
customer_path = path.join("customers", args.customer_name, "envs") customer_path = path.join("customers", args.customer_name, "envs")
@ -112,6 +113,17 @@ def delete_env(args: argparse.Namespace):
shutil.rmtree(env_path) shutil.rmtree(env_path)
print(f"Deleted `{env}` from customer `{ args.customer_name}`") print(f"Deleted `{env}` from customer `{ args.customer_name}`")
def recover_env(args: argparse.Namespace):
for env in args.env_names:
env_path = path.join("customers", args.customer_name, "envs", env)
sub.call(["vagrant", "up"], cwd=env_path)
# Artificial sleep waiting for sshd on VM to correctly start
print("Waiting on virtual machines...")
time.sleep(5)
sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path)
print(f"Recovered `{env}` from customer `{ args.customer_name}`")
def create_env(args: argparse.Namespace): def create_env(args: argparse.Namespace):
if (args.num_nginx_web + args.num_nginx_lb + args.num_postgres) == 0: if (args.num_nginx_web + args.num_nginx_lb + args.num_postgres) == 0:
@ -119,21 +131,34 @@ def create_env(args: argparse.Namespace):
env_path = path.join("customers", args.customer_name, "envs", args.env_name) env_path = path.join("customers", args.customer_name, "envs", args.env_name)
Path(env_path).mkdir(exist_ok=True, parents=True) Path(env_path).mkdir(exist_ok=True, parents=True)
vagrant_mapping: dict[str, Any] = {
"webserver_specs": None,
"loadbalancers_specs": None,
"postgres_specs": None,
"num_webserver": args.num_nginx_web,
"num_loadbalancers": args.num_nginx_lb,
"num_postgres": args.num_postgres,
"env": args.env_name,
"customer_name": args.customer_name,
"ip_int": args.ip_int,
"ip_format": args.ip_format.replace("{}", "%d"),
}
web_specs = None
if args.num_nginx_web > 0: if args.num_nginx_web > 0:
print("\nNginx webserver resources:") print("\nNginx webserver resources:")
web_specs = MachineResources.from_prompt() web_specs = MachineResources.from_prompt()
vagrant_mapping["webserver_specs"] = dataclasses.asdict(web_specs)
lb_specs = None
if args.num_nginx_lb > 0: if args.num_nginx_lb > 0:
print("\nNginx loadbalancer resources: ") print("\nNginx loadbalancer resources: ")
lb_specs = MachineResources.from_prompt() lb_specs = MachineResources.from_prompt()
vagrant_mapping["loadbalancers_specs"] = dataclasses.asdict(lb_specs)
psql_specs = None
if args.num_postgres > 0: if args.num_postgres > 0:
print("\nPostgresql machine resources: ") print("\nPostgresql machine resources: ")
psql_specs = MachineResources.from_prompt() psql_specs = MachineResources.from_prompt()
vagrant_mapping["postgres_specs"] = dataclasses.asdict(psql_specs)
# Template `ansible.cfg` # Template `ansible.cfg`
src = path.join(TEMPLATE_DIR, "ansible.cfg.template") src = path.join(TEMPLATE_DIR, "ansible.cfg.template")
@ -151,9 +176,8 @@ def create_env(args: argparse.Namespace):
lb_ips = itertools.islice(ip_generator, args.num_nginx_lb) lb_ips = itertools.islice(ip_generator, args.num_nginx_lb)
iw.add("loadbalancer", lb_ips) iw.add("loadbalancer", lb_ips)
psql_gen_pass: Callable[[str], str] = lambda x: encode_member( def psql_gen_pass(x: str) -> str:
x, {"psql_pass": gen_pass()} return encode_member(x, {"psql_pass": gen_pass()})
)
psql_ips = list(itertools.islice(ip_generator, args.num_postgres)) psql_ips = list(itertools.islice(ip_generator, args.num_postgres))
psql_ips = map(psql_gen_pass, psql_ips) psql_ips = map(psql_gen_pass, psql_ips)
@ -164,28 +188,16 @@ def create_env(args: argparse.Namespace):
# Template `Vagrantfile` # Template `Vagrantfile`
src = path.join(TEMPLATE_DIR, "Vagrantfile.template") src = path.join(TEMPLATE_DIR, "Vagrantfile.template")
dest = path.join(env_path, "Vagrantfile") dest = path.join(env_path, "Vagrantfile")
should_reload = Path(dest).exists()
copy_template(src=src, dest=dest, mapping=vagrant_mapping)
mapping = { # Generate .ssh if it doesn't exist already
"env": args.env_name,
"customer_name": args.customer_name,
"ip_int": args.ip_int,
"ip_format": args.ip_format.replace("{}", "%d"),
"num_webserver": args.num_nginx_web,
"num_loadbalancers": args.num_nginx_lb,
"num_postgres": args.num_postgres,
"webserver_specs": dataclasses.asdict(web_specs),
"loadbalancers_specs": dataclasses.asdict(lb_specs),
"postgres_specs": dataclasses.asdict(psql_specs),
}
copy_template(src=src, dest=dest, mapping=mapping)
# Generate .ssh
ssh_dir = path.join(env_path, ".ssh") ssh_dir = path.join(env_path, ".ssh")
Path(ssh_dir).mkdir(exist_ok=True) Path(ssh_dir).mkdir(exist_ok=True)
rsa_path = path.join(ssh_dir, "id_rsa") rsa_path = path.join(ssh_dir, "id_rsa")
if not Path(rsa_path).exists(): if not Path(rsa_path).exists():
print(end="\n") print(end="\n")
ssh_key_cmd = [ sub.call([
"ssh-keygen", "ssh-keygen",
"-t", "-t",
"rsa", "rsa",
@ -193,24 +205,34 @@ def create_env(args: argparse.Namespace):
"2048", "2048",
"-f", "-f",
rsa_path, rsa_path,
] ])
sub.call(ssh_key_cmd)
# Provision and configure machines # Provision and configure machines
# Create VM's that do not exist yet
sub.call(["vagrant", "up"], cwd=env_path) sub.call(["vagrant", "up"], cwd=env_path)
# Update VM's that already existed
if should_reload:
sub.call(["vagrant", "reload", "--provision"], cwd=env_path)
# Artificial sleep waiting for sshd on VM to correctly start
print("Waiting on virtual machines...") print("Waiting on virtual machines...")
time.sleep(5) time.sleep(5)
sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path) sub.call(["ansible-playbook", "../../../../site.yml"], cwd=env_path)
#endregion
def main() -> int: def main() -> int:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
sub_parser = parser.add_subparsers() sub_parser = parser.add_subparsers()
# CLI definition for positional arg "list"
list_parser = sub_parser.add_parser("list", help="list customer-owned environments") list_parser = sub_parser.add_parser("list", help="list customer-owned environments")
list_parser.add_argument("customer_name", type=str, help="name of the customer") list_parser.add_argument("customer_name", type=str, help="name of the customer")
list_parser.set_defaults(func=list_envs) list_parser.set_defaults(func=list_envs)
# CLI definition for positional arg "create"
cenv_parser = sub_parser.add_parser("create", help="create a new environment") cenv_parser = sub_parser.add_parser("create", help="create a new environment")
cenv_parser.add_argument("customer_name", type=str, help="name of the customer") cenv_parser.add_argument("customer_name", type=str, help="name of the customer")
cenv_parser.add_argument("env_name", type=str, help="name of the environment") cenv_parser.add_argument("env_name", type=str, help="name of the environment")
@ -240,14 +262,23 @@ def main() -> int:
) )
cenv_parser.set_defaults(func=create_env) cenv_parser.set_defaults(func=create_env)
# CLI definition for positional arg "delete"
denv_parser = sub_parser.add_parser("delete", help="delete an environment") denv_parser = sub_parser.add_parser("delete", help="delete an environment")
denv_parser.add_argument("customer_name", type=str, help="name of the customer") denv_parser.add_argument("customer_name", type=str, help="name of the customer")
denv_parser.add_argument( denv_parser.add_argument(
"env_names", type=str, nargs="+", help="name of one or more environments" "env_names", type=str, nargs="+", help="name of one or more environments"
) )
denv_parser.set_defaults(func=delete_env) denv_parser.set_defaults(func=delete_env)
# CLI definition for positional arg "recover"
denv_parser = sub_parser.add_parser("recover", help="attempts to recover an env")
denv_parser.add_argument("customer_name", type=str, help="name of the customer")
denv_parser.add_argument(
"env_names", type=str, nargs="+", help="name of one or more environments"
)
denv_parser.set_defaults(func=recover_env)
# Parse args
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
args.func(args) args.func(args)
return 0 return 0

View File

@ -10,11 +10,10 @@ Vagrant.configure("2") do |config|
config.ssh.insert_key = false config.ssh.insert_key = false
config.ssh.private_key_path = ["./.ssh/id_rsa","~/.vagrant.d/insecure_private_key"] config.ssh.private_key_path = ["./.ssh/id_rsa","~/.vagrant.d/insecure_private_key"]
num_webserver = {{ num_webserver }}
num_loadbalancer = {{ num_loadbalancers }}
num_postgresql = {{ num_postgres }} num_postgresql = {{ num_postgres }}
(1..num_webserver).each do |nth| {% if webserver_specs is not none %}
(1..{{ num_webserver }}).each do |nth|
machine_id = "{{ customer_name }}-{{ env }}-web%d" % [nth] machine_id = "{{ customer_name }}-{{ env }}-web%d" % [nth]
machine_ip = increment_ip() machine_ip = increment_ip()
@ -33,8 +32,10 @@ Vagrant.configure("2") do |config|
end end
end end
end end
{% endif %}
(1..num_loadbalancer).each do |nth| {% if loadbalancers_specs is not none %}
(1..{{ num_loadbalancers }}).each do |nth|
machine_id = "{{ customer_name }}-{{ env }}-lb%d" % [nth] machine_id = "{{ customer_name }}-{{ env }}-lb%d" % [nth]
machine_ip = increment_ip() machine_ip = increment_ip()
@ -53,8 +54,10 @@ Vagrant.configure("2") do |config|
end end
end end
end end
{% endif %}
(1..num_postgresql).each do |nth| {% if postgres_specs is not none %}
(1..{{ num_postgres }}).each do |nth|
machine_id = "{{ customer_name }}-{{ env }}-db%d" % [nth] machine_id = "{{ customer_name }}-{{ env }}-db%d" % [nth]
machine_ip = increment_ip() machine_ip = increment_ip()
@ -73,4 +76,5 @@ Vagrant.configure("2") do |config|
end end
end end
end end
{% endif %}
end end