template.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. from __future__ import annotations
  2. from .variables import Variable, VariableCollection
  3. from pathlib import Path
  4. from typing import Any, Dict, List, Set, Optional, Literal
  5. from dataclasses import dataclass, field
  6. import logging
  7. import os
  8. import yaml
  9. from jinja2 import Environment, FileSystemLoader, meta
  10. from jinja2 import nodes
  11. from jinja2.visitor import NodeVisitor
  12. logger = logging.getLogger(__name__)
  13. # -----------------------
  14. # SECTION: TemplateFile Class
  15. # -----------------------
  16. @dataclass
  17. class TemplateFile:
  18. """Represents a single file within a template directory."""
  19. relative_path: Path
  20. file_type: Literal['j2', 'static']
  21. output_path: Path # The path it will have in the output directory
  22. # !SECTION
  23. # -----------------------
  24. # SECTION: Metadata Class
  25. # -----------------------
  26. @dataclass
  27. class TemplateMetadata:
  28. """Represents template metadata with proper typing."""
  29. name: str
  30. description: str
  31. author: str
  32. date: str
  33. version: str
  34. module: str = ""
  35. tags: List[str] = field(default_factory=list)
  36. # files: List[str] = field(default_factory=list) # No longer needed, as TemplateFile handles this
  37. library: str = "unknown"
  38. def __init__(self, template_data: dict, library_name: str | None = None) -> None:
  39. """Initialize TemplateMetadata from parsed YAML template data.
  40. Args:
  41. template_data: Parsed YAML data from template.yaml
  42. library_name: Name of the library this template belongs to
  43. """
  44. # Validate metadata format first
  45. self._validate_metadata(template_data)
  46. # Extract metadata section
  47. metadata_section = template_data.get("metadata", {})
  48. self.name = metadata_section.get("name", "")
  49. # YAML block scalar (|) preserves a trailing newline. Remove only trailing newlines
  50. # while preserving internal newlines/formatting.
  51. raw_description = metadata_section.get("description", "")
  52. if isinstance(raw_description, str):
  53. description = raw_description.rstrip("\n")
  54. else:
  55. description = str(raw_description)
  56. self.description = description or "No description available"
  57. self.author = metadata_section.get("author", "")
  58. self.date = metadata_section.get("date", "")
  59. self.version = metadata_section.get("version", "")
  60. self.module = metadata_section.get("module", "")
  61. self.tags = metadata_section.get("tags", []) or []
  62. # self.files = metadata_section.get("files", []) or [] # No longer needed
  63. self.library = library_name or "unknown"
  64. @staticmethod
  65. def _validate_metadata(template_data: dict) -> None:
  66. """Validate that template has required 'metadata' section with all required fields.
  67. Args:
  68. template_data: Parsed YAML data from template.yaml
  69. Raises:
  70. ValueError: If metadata section is missing or incomplete
  71. """
  72. metadata_section = template_data.get("metadata")
  73. if metadata_section is None:
  74. raise ValueError("Template format error: missing 'metadata' section")
  75. # Validate that metadata section has all required fields
  76. required_fields = ["name", "author", "version", "date", "description"]
  77. missing_fields = [field for field in required_fields if not metadata_section.get(field)]
  78. if missing_fields:
  79. raise ValueError(f"Template format error: missing required metadata fields: {missing_fields}")
  80. # !SECTION
  81. # -----------------------
  82. # SECTION: Template Class
  83. # -----------------------
  84. @dataclass
  85. class Template:
  86. """Represents a template directory."""
  87. def __init__(self, template_dir: Path, library_name: str) -> None:
  88. """Create a Template instance from a directory path."""
  89. logger.debug(f"Loading template from directory: {template_dir}")
  90. self.template_dir = template_dir
  91. self.id = template_dir.name
  92. self.library_name = library_name
  93. # Initialize caches for lazy loading
  94. self.__module_specs: Optional[dict] = None
  95. self.__merged_specs: Optional[dict] = None
  96. self.__jinja_env: Optional[Environment] = None
  97. self.__used_variables: Optional[Set[str]] = None
  98. self.__variables: Optional[VariableCollection] = None
  99. self.__template_files: Optional[List[TemplateFile]] = None # New attribute
  100. try:
  101. # Find and parse the main template file (template.yaml or template.yml)
  102. main_template_path = self._find_main_template_file()
  103. with open(main_template_path, "r", encoding="utf-8") as f:
  104. # Load all YAML documents (handles templates with empty lines before ---)
  105. documents = list(yaml.safe_load_all(f))
  106. # Filter out None/empty documents and get the first non-empty one
  107. valid_docs = [doc for doc in documents if doc is not None]
  108. if not valid_docs:
  109. raise ValueError("Template file contains no valid YAML data")
  110. if len(valid_docs) > 1:
  111. logger.warning(f"Template file contains multiple YAML documents, using the first one")
  112. self._template_data = valid_docs[0]
  113. # Validate template data
  114. if not isinstance(self._template_data, dict):
  115. raise ValueError("Template file must contain a valid YAML dictionary")
  116. # Load metadata (always needed)
  117. self.metadata = TemplateMetadata(self._template_data, library_name)
  118. logger.debug(f"Loaded metadata: {self.metadata}")
  119. # Validate 'kind' field (always needed)
  120. self._validate_kind(self._template_data)
  121. # Collect file paths (relatively lightweight, needed for various lazy loads)
  122. # This will now populate self.template_files
  123. self._collect_template_files()
  124. logger.info(f"Loaded template '{self.id}' (v{self.metadata.version})")
  125. except (ValueError, FileNotFoundError) as e:
  126. logger.error(f"Error loading template from {template_dir}: {e}")
  127. raise
  128. except Exception as e:
  129. logger.error(f"An unexpected error occurred while loading template {template_dir}: {e}")
  130. raise
  131. def _find_main_template_file(self) -> Path:
  132. """Find the main template file (template.yaml or template.yml)."""
  133. for filename in ["template.yaml", "template.yml"]:
  134. path = self.template_dir / filename
  135. if path.exists():
  136. return path
  137. raise FileNotFoundError(f"Main template file (template.yaml or template.yml) not found in {self.template_dir}")
  138. def _load_module_specs(self, kind: str) -> dict:
  139. """Load specifications from the corresponding module."""
  140. if not kind:
  141. return {}
  142. try:
  143. import importlib
  144. module = importlib.import_module(f"..modules.{kind}", package=__package__)
  145. return getattr(module, 'spec', {})
  146. except Exception as e:
  147. raise ValueError(f"Error loading module specifications for kind '{kind}': {e}")
  148. def _merge_specs(self, module_specs: dict, template_specs: dict) -> dict:
  149. """Deep merge template specs with module specs using VariableCollection.
  150. Uses VariableCollection's native merge() method for consistent merging logic.
  151. Module specs are base, template specs override with origin tracking.
  152. """
  153. # Create VariableCollection from module specs (base)
  154. module_collection = VariableCollection(module_specs) if module_specs else VariableCollection({})
  155. # Set origin for module variables
  156. for section in module_collection.get_sections().values():
  157. for variable in section.variables.values():
  158. if not variable.origin:
  159. variable.origin = "module"
  160. # Merge template specs into module specs (template overrides)
  161. if template_specs:
  162. merged_collection = module_collection.merge(template_specs, origin="template")
  163. else:
  164. merged_collection = module_collection
  165. # Convert back to dict format
  166. merged_spec = {}
  167. for section_key, section in merged_collection.get_sections().items():
  168. merged_spec[section_key] = section.to_dict()
  169. return merged_spec
  170. def _collect_template_files(self) -> None:
  171. """Collects all TemplateFile objects in the template directory."""
  172. template_files: List[TemplateFile] = []
  173. for root, _, files in os.walk(self.template_dir):
  174. for filename in files:
  175. file_path = Path(root) / filename
  176. relative_path = file_path.relative_to(self.template_dir)
  177. # Skip the main template file
  178. if filename in ["template.yaml", "template.yml"]:
  179. continue
  180. if filename.endswith(".j2"):
  181. file_type: Literal['j2', 'static'] = 'j2'
  182. output_path = relative_path.with_suffix('') # Remove .j2 suffix
  183. else:
  184. file_type = 'static'
  185. output_path = relative_path # Static files keep their name
  186. template_files.append(TemplateFile(relative_path=relative_path, file_type=file_type, output_path=output_path))
  187. self.__template_files = template_files
  188. def _extract_all_used_variables(self) -> Set[str]:
  189. """Extract all undeclared variables from all .j2 files in the template directory.
  190. Raises:
  191. ValueError: If any Jinja2 template has syntax errors
  192. """
  193. used_variables: Set[str] = set()
  194. syntax_errors = []
  195. for template_file in self.template_files: # Iterate over TemplateFile objects
  196. if template_file.file_type == 'j2':
  197. file_path = self.template_dir / template_file.relative_path
  198. try:
  199. with open(file_path, "r", encoding="utf-8") as f:
  200. content = f.read()
  201. ast = self.jinja_env.parse(content) # Use lazy-loaded jinja_env
  202. used_variables.update(meta.find_undeclared_variables(ast))
  203. except Exception as e:
  204. # Collect syntax errors instead of just warning
  205. relative_path = file_path.relative_to(self.template_dir)
  206. syntax_errors.append(f" - {relative_path}: {e}")
  207. # Raise error if any syntax errors were found
  208. if syntax_errors:
  209. error_msg = (
  210. f"Jinja2 syntax errors found in template '{self.id}':\n" +
  211. "\n".join(syntax_errors) +
  212. "\n\nPlease fix the syntax errors in the template files."
  213. )
  214. logger.error(error_msg)
  215. raise ValueError(error_msg)
  216. return used_variables
  217. def _extract_jinja_default_values(self) -> dict[str, object]:
  218. """Scan all .j2 files and extract literal arguments to the `default` filter.
  219. Returns a mapping var_name -> literal_value for simple cases like
  220. {{ var | default("value") }} or {{ var | default(123) }}.
  221. This does not attempt to evaluate complex expressions.
  222. """
  223. defaults: dict[str, object] = {}
  224. class _DefaultVisitor(NodeVisitor):
  225. def __init__(self):
  226. self.found: dict[str, object] = {}
  227. def visit_Filter(self, node: nodes.Filter) -> None: # type: ignore[override]
  228. try:
  229. if getattr(node, 'name', None) == 'default' and node.args:
  230. # target variable name when filter is applied directly to a Name
  231. target = None
  232. if isinstance(node.node, nodes.Name):
  233. target = node.node.name
  234. # first arg literal
  235. first = node.args[0]
  236. if isinstance(first, nodes.Const) and target:
  237. self.found[target] = first.value
  238. except Exception:
  239. # Be resilient to unexpected node shapes
  240. pass
  241. # continue traversal
  242. self.generic_visit(node)
  243. visitor = _DefaultVisitor()
  244. for template_file in self.template_files:
  245. if template_file.file_type != 'j2':
  246. continue
  247. file_path = self.template_dir / template_file.relative_path
  248. try:
  249. with open(file_path, 'r', encoding='utf-8') as f:
  250. content = f.read()
  251. ast = self.jinja_env.parse(content)
  252. visitor.visit(ast)
  253. except Exception:
  254. # skip failures - this extraction is best-effort only
  255. continue
  256. return visitor.found
  257. def _filter_specs_to_used(self, used_variables: set, merged_specs: dict, module_specs: dict, template_specs: dict) -> dict:
  258. """Filter specs to only include variables used in templates using VariableCollection.
  259. Uses VariableCollection's native filter_to_used() method.
  260. Keeps sensitive variables only if they're defined in the template spec or actually used.
  261. """
  262. # Build set of variables explicitly defined in template spec
  263. template_defined_vars = set()
  264. for section_data in (template_specs or {}).values():
  265. if isinstance(section_data, dict) and 'vars' in section_data:
  266. template_defined_vars.update(section_data['vars'].keys())
  267. # Create VariableCollection from merged specs
  268. merged_collection = VariableCollection(merged_specs)
  269. # Filter to only used variables (and sensitive ones that are template-defined)
  270. # We keep sensitive variables that are either:
  271. # 1. Actually used in template files, OR
  272. # 2. Explicitly defined in the template spec (even if not yet used)
  273. variables_to_keep = used_variables | template_defined_vars
  274. filtered_collection = merged_collection.filter_to_used(variables_to_keep, keep_sensitive=False)
  275. # Convert back to dict format
  276. filtered_specs = {}
  277. for section_key, section in filtered_collection.get_sections().items():
  278. filtered_specs[section_key] = section.to_dict()
  279. return filtered_specs
  280. # ---------------------------
  281. # SECTION: Validation Methods
  282. # ---------------------------
  283. @staticmethod
  284. def _validate_kind(template_data: dict) -> None:
  285. """Validate that template has required 'kind' field.
  286. Args:
  287. template_data: Parsed YAML data from template.yaml
  288. Raises:
  289. ValueError: If 'kind' field is missing
  290. """
  291. if not template_data.get("kind"):
  292. raise ValueError("Template format error: missing 'kind' field")
  293. def _validate_variable_definitions(self, used_variables: set[str], merged_specs: dict[str, Any]) -> None:
  294. """Validate that all variables used in Jinja2 content are defined in the spec."""
  295. defined_variables = set()
  296. for section_data in merged_specs.values():
  297. if "vars" in section_data and isinstance(section_data["vars"], dict):
  298. defined_variables.update(section_data["vars"].keys())
  299. undefined_variables = used_variables - defined_variables
  300. if undefined_variables:
  301. undefined_list = sorted(undefined_variables)
  302. error_msg = (
  303. f"Template validation error in '{self.id}': "
  304. f"Variables used in template content but not defined in spec: {undefined_list}\n\n"
  305. f"Please add these variables to your template's template.yaml spec. "
  306. f"Each variable must have a default value.\n\n"
  307. f"Example:\n"
  308. f"spec:\n"
  309. f" general:\n"
  310. f" vars:\n"
  311. )
  312. for var_name in undefined_list:
  313. error_msg += (
  314. f" {var_name}:\n"
  315. f" type: str\n"
  316. f" description: Description for {var_name}\n"
  317. f" default: <your_default_value_here>\n"
  318. )
  319. logger.error(error_msg)
  320. raise ValueError(error_msg)
  321. # !SECTION
  322. # ---------------------------------
  323. # SECTION: Jinja2 Rendering Methods
  324. # ---------------------------------
  325. @staticmethod
  326. def _create_jinja_env(searchpath: Path) -> Environment:
  327. """Create standardized Jinja2 environment for consistent template processing.
  328. Includes custom filters for generating random values:
  329. - random_string(length): Generate random alphanumeric string
  330. - random_hex(length): Generate random hexadecimal string
  331. - random_base64(length): Generate random base64 string
  332. - random_uuid: Generate a UUID4
  333. """
  334. import secrets
  335. import string
  336. import base64
  337. import uuid
  338. env = Environment(
  339. loader=FileSystemLoader(searchpath),
  340. trim_blocks=True,
  341. lstrip_blocks=True,
  342. keep_trailing_newline=False,
  343. )
  344. # Add custom filters for generating random values
  345. def random_string(value: str = '', length: int = 32) -> str:
  346. """Generate a random alphanumeric string of specified length.
  347. Usage: {{ '' | random_string(64) }} or {{ 'ignored' | random_string(32) }}
  348. """
  349. alphabet = string.ascii_letters + string.digits
  350. return ''.join(secrets.choice(alphabet) for _ in range(length))
  351. def pwgen(value: str = '', length: int = 50) -> str:
  352. """Generate a secure random string (mimics pwgen -s).
  353. Default length is 50 (matching Authentik recommendation).
  354. Usage: {{ '' | pwgen }} or {{ '' | pwgen(64) }}
  355. """
  356. alphabet = string.ascii_letters + string.digits
  357. return ''.join(secrets.choice(alphabet) for _ in range(length))
  358. def random_hex(value: str = '', length: int = 32) -> str:
  359. """Generate a random hexadecimal string of specified length.
  360. Usage: {{ '' | random_hex(64) }}
  361. """
  362. return secrets.token_hex(length // 2)
  363. def random_base64(value: str = '', length: int = 32) -> str:
  364. """Generate a random base64 string of specified length.
  365. Usage: {{ '' | random_base64(64) }}
  366. """
  367. num_bytes = (length * 3) // 4 # Convert length to approximate bytes
  368. return base64.b64encode(secrets.token_bytes(num_bytes)).decode('utf-8')[:length]
  369. def random_uuid(value: str = '') -> str:
  370. """Generate a random UUID4.
  371. Usage: {{ '' | random_uuid }}
  372. """
  373. return str(uuid.uuid4())
  374. # Register filters
  375. env.filters['random_string'] = random_string
  376. env.filters['pwgen'] = pwgen
  377. env.filters['random_hex'] = random_hex
  378. env.filters['random_base64'] = random_base64
  379. env.filters['random_uuid'] = random_uuid
  380. return env
  381. def render(self, variables: VariableCollection) -> Dict[str, str]:
  382. """Render all .j2 files in the template directory."""
  383. # Use get_satisfied_values() to exclude variables from sections with unsatisfied dependencies
  384. variable_values = variables.get_satisfied_values()
  385. logger.debug(f"Rendering template '{self.id}' with variables: {variable_values}")
  386. rendered_files = {}
  387. for template_file in self.template_files: # Iterate over TemplateFile objects
  388. if template_file.file_type == 'j2':
  389. try:
  390. template = self.jinja_env.get_template(str(template_file.relative_path)) # Use lazy-loaded jinja_env
  391. rendered_content = template.render(**variable_values)
  392. # Sanitize the rendered content to remove excessive blank lines
  393. rendered_content = self._sanitize_content(rendered_content, template_file.output_path)
  394. rendered_files[str(template_file.output_path)] = rendered_content
  395. except Exception as e:
  396. logger.error(f"Error rendering template file {template_file.relative_path}: {e}")
  397. raise
  398. elif template_file.file_type == 'static':
  399. # For static files, just read their content and add to rendered_files
  400. # This ensures static files are also part of the output dictionary
  401. file_path = self.template_dir / template_file.relative_path
  402. try:
  403. with open(file_path, "r", encoding="utf-8") as f:
  404. content = f.read()
  405. rendered_files[str(template_file.output_path)] = content
  406. except Exception as e:
  407. logger.error(f"Error reading static file {file_path}: {e}")
  408. raise
  409. return rendered_files
  410. def _sanitize_content(self, content: str, file_path: Path) -> str:
  411. """Sanitize rendered content by removing excessive blank lines.
  412. This function:
  413. - Reduces multiple consecutive blank lines to a maximum of one blank line
  414. - Preserves file structure and readability
  415. - Removes trailing whitespace from lines
  416. - Ensures file ends with a single newline
  417. Args:
  418. content: The rendered content to sanitize
  419. file_path: Path to the output file (used for file-type detection)
  420. Returns:
  421. Sanitized content with cleaned up blank lines
  422. """
  423. if not content:
  424. return content
  425. # Split content into lines
  426. lines = content.split('\n')
  427. sanitized_lines = []
  428. blank_line_count = 0
  429. for line in lines:
  430. # Remove trailing whitespace from the line
  431. cleaned_line = line.rstrip()
  432. # Check if this is a blank line
  433. if not cleaned_line:
  434. blank_line_count += 1
  435. # Only keep the first blank line in a sequence
  436. if blank_line_count == 1:
  437. sanitized_lines.append('')
  438. else:
  439. # Reset counter when we hit a non-blank line
  440. blank_line_count = 0
  441. sanitized_lines.append(cleaned_line)
  442. # Join lines back together
  443. result = '\n'.join(sanitized_lines)
  444. # Remove leading blank lines
  445. result = result.lstrip('\n')
  446. # Ensure file ends with exactly one newline
  447. result = result.rstrip('\n') + '\n'
  448. return result
  449. def mask_sensitive_values(self, rendered_files: Dict[str, str], variables: VariableCollection) -> Dict[str, str]:
  450. """Mask sensitive values in rendered files using Variable's native masking."""
  451. masked_files = {}
  452. # Get all variables (not just sensitive ones) to use their native get_display_value()
  453. for file_path, content in rendered_files.items():
  454. # Iterate through all sections and variables
  455. for section in variables.get_sections().values():
  456. for variable in section.variables.values():
  457. if variable.sensitive and variable.value:
  458. # Use variable's native masking - always returns "********" for sensitive vars
  459. masked_value = variable.get_display_value(mask_sensitive=True)
  460. content = content.replace(str(variable.value), masked_value)
  461. masked_files[file_path] = content
  462. return masked_files
  463. # !SECTION
  464. # ---------------------------
  465. # SECTION: Lazy Loaded Properties
  466. # ---------------------------
  467. @property
  468. def template_files(self) -> List[TemplateFile]:
  469. if self.__template_files is None:
  470. self._collect_template_files() # Populate self.__template_files
  471. return self.__template_files
  472. @property
  473. def template_specs(self) -> dict:
  474. """Get the spec section from template YAML data."""
  475. return self._template_data.get("spec", {})
  476. @property
  477. def module_specs(self) -> dict:
  478. """Get the spec from the module definition."""
  479. if self.__module_specs is None:
  480. kind = self._template_data.get("kind")
  481. self.__module_specs = self._load_module_specs(kind)
  482. return self.__module_specs
  483. @property
  484. def merged_specs(self) -> dict:
  485. if self.__merged_specs is None:
  486. self.__merged_specs = self._merge_specs(self.module_specs, self.template_specs)
  487. return self.__merged_specs
  488. @property
  489. def jinja_env(self) -> Environment:
  490. if self.__jinja_env is None:
  491. self.__jinja_env = self._create_jinja_env(self.template_dir)
  492. return self.__jinja_env
  493. @property
  494. def used_variables(self) -> Set[str]:
  495. if self.__used_variables is None:
  496. self.__used_variables = self._extract_all_used_variables()
  497. return self.__used_variables
  498. @property
  499. def variables(self) -> VariableCollection:
  500. if self.__variables is None:
  501. # Validate that all used variables are defined
  502. self._validate_variable_definitions(self.used_variables, self.merged_specs)
  503. # Filter specs to only used variables
  504. filtered_specs = self._filter_specs_to_used(self.used_variables, self.merged_specs, self.module_specs, self.template_specs)
  505. # Best-effort: extract literal defaults from Jinja `default()` filter and
  506. # merge them into the filtered_specs when no default exists there.
  507. try:
  508. jinja_defaults = self._extract_jinja_default_values()
  509. for section_key, section_data in filtered_specs.items():
  510. # Guard against None from empty YAML sections
  511. vars_dict = section_data.get('vars') or {}
  512. for var_name, var_data in vars_dict.items():
  513. if 'default' not in var_data or var_data.get('default') in (None, ''):
  514. if var_name in jinja_defaults:
  515. var_data['default'] = jinja_defaults[var_name]
  516. except Exception:
  517. # keep behavior stable on any extraction errors
  518. pass
  519. self.__variables = VariableCollection(filtered_specs)
  520. # Sort sections: required first, then enabled, then disabled
  521. self.__variables.sort_sections()
  522. return self.__variables