template.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from __future__ import annotations
  2. from .variables import Variable, VariableCollection
  3. from pathlib import Path
  4. from typing import Any, Dict, List, Set, Optional, Literal
  5. from dataclasses import dataclass, field
  6. import logging
  7. import os
  8. from jinja2 import Environment, FileSystemLoader, meta
  9. import frontmatter
  10. logger = logging.getLogger(__name__)
  11. # -----------------------
  12. # SECTION: TemplateFile Class
  13. # -----------------------
  14. @dataclass
  15. class TemplateFile:
  16. """Represents a single file within a template directory."""
  17. relative_path: Path
  18. file_type: Literal['j2', 'static']
  19. output_path: Path # The path it will have in the output directory
  20. # !SECTION
  21. # -----------------------
  22. # SECTION: Metadata Class
  23. # -----------------------
  24. @dataclass
  25. class TemplateMetadata:
  26. """Represents template metadata with proper typing."""
  27. name: str
  28. description: str
  29. author: str
  30. date: str
  31. version: str
  32. module: str = ""
  33. tags: List[str] = field(default_factory=list)
  34. # files: List[str] = field(default_factory=list) # No longer needed, as TemplateFile handles this
  35. library: str = "unknown"
  36. def __init__(self, post: frontmatter.Post, library_name: str | None = None) -> None:
  37. """Initialize TemplateMetadata from frontmatter post."""
  38. # Validate metadata format first
  39. self._validate_metadata(post)
  40. # Extract metadata section
  41. metadata_section = post.metadata.get("metadata", {})
  42. self.name = metadata_section.get("name", "")
  43. self.description = metadata_section.get("description", "No description available")
  44. self.author = metadata_section.get("author", "")
  45. self.date = metadata_section.get("date", "")
  46. self.version = metadata_section.get("version", "")
  47. self.module = metadata_section.get("module", "")
  48. self.tags = metadata_section.get("tags", []) or []
  49. # self.files = metadata_section.get("files", []) or [] # No longer needed
  50. self.library = library_name or "unknown"
  51. @staticmethod
  52. def _validate_metadata(post: frontmatter.Post) -> None:
  53. """Validate that template has required 'metadata' section with all required fields."""
  54. metadata_section = post.metadata.get("metadata")
  55. if metadata_section is None:
  56. raise ValueError("Template format error: missing 'metadata' section")
  57. # Validate that metadata section has all required fields
  58. required_fields = ["name", "author", "version", "date", "description"]
  59. missing_fields = [field for field in required_fields if not metadata_section.get(field)]
  60. if missing_fields:
  61. raise ValueError(f"Template format error: missing required metadata fields: {missing_fields}")
  62. # !SECTION
  63. # -----------------------
  64. # SECTION: Template Class
  65. # -----------------------
  66. @dataclass
  67. class Template:
  68. """Represents a template directory."""
  69. def __init__(self, template_dir: Path, library_name: str) -> None:
  70. """Create a Template instance from a directory path."""
  71. logger.debug(f"Loading template from directory: {template_dir}")
  72. self.template_dir = template_dir
  73. self.id = template_dir.name
  74. self.library_name = library_name
  75. # Initialize caches for lazy loading
  76. self.__module_specs: Optional[dict] = None
  77. self.__merged_specs: Optional[dict] = None
  78. self.__jinja_env: Optional[Environment] = None
  79. self.__used_variables: Optional[Set[str]] = None
  80. self.__variables: Optional[VariableCollection] = None
  81. self.__template_files: Optional[List[TemplateFile]] = None # New attribute
  82. try:
  83. # Find and parse the main template file (template.yaml or template.yml)
  84. main_template_path = self._find_main_template_file()
  85. with open(main_template_path, "r", encoding="utf-8") as f:
  86. self._post = frontmatter.load(f) # Store post for later access to spec
  87. # Load metadata (always needed)
  88. self.metadata = TemplateMetadata(self._post, library_name)
  89. logger.debug(f"Loaded metadata: {self.metadata}")
  90. # Validate 'kind' field (always needed)
  91. self._validate_kind(self._post)
  92. # Collect file paths (relatively lightweight, needed for various lazy loads)
  93. # This will now populate self.template_files
  94. self._collect_template_files()
  95. logger.info(f"Loaded template '{self.id}' (v{self.metadata.version})")
  96. except (ValueError, FileNotFoundError) as e:
  97. logger.error(f"Error loading template from {template_dir}: {e}")
  98. raise
  99. except Exception as e:
  100. logger.error(f"An unexpected error occurred while loading template {template_dir}: {e}")
  101. raise
  102. def _find_main_template_file(self) -> Path:
  103. """Find the main template file (template.yaml or template.yml)."""
  104. for filename in ["template.yaml", "template.yml"]:
  105. path = self.template_dir / filename
  106. if path.exists():
  107. return path
  108. raise FileNotFoundError(f"Main template file (template.yaml or template.yml) not found in {self.template_dir}")
  109. def _load_module_specs(self, kind: str) -> dict:
  110. """Load specifications from the corresponding module."""
  111. if not kind:
  112. return {}
  113. try:
  114. import importlib
  115. module = importlib.import_module(f"..modules.{kind}", package=__package__)
  116. return getattr(module, 'spec', {})
  117. except Exception as e:
  118. raise ValueError(f"Error loading module specifications for kind '{kind}': {e}")
  119. def _merge_specs(self, module_specs: dict, template_specs: dict) -> dict:
  120. """Deep merge template specs with module specs."""
  121. merged_specs = {}
  122. for section_key in module_specs.keys():
  123. module_section = module_specs.get(section_key, {})
  124. template_section = template_specs.get(section_key, {})
  125. merged_section = {**module_section}
  126. for key in ['title', 'prompt', 'description', 'toggle', 'required']:
  127. if key in template_section:
  128. merged_section[key] = template_section[key]
  129. module_vars = module_section.get('vars') if isinstance(module_section.get('vars'), dict) else {}
  130. template_vars = template_section.get('vars') if isinstance(template_section.get('vars'), dict) else {}
  131. merged_section['vars'] = {**module_vars, **template_vars}
  132. merged_specs[section_key] = merged_section
  133. for section_key in template_specs.keys():
  134. if section_key not in module_specs:
  135. merged_specs[section_key] = {**template_specs[section_key]}
  136. return merged_specs
  137. def _collect_template_files(self) -> None:
  138. """Collects all TemplateFile objects in the template directory."""
  139. template_files: List[TemplateFile] = []
  140. for root, _, files in os.walk(self.template_dir):
  141. for filename in files:
  142. file_path = Path(root) / filename
  143. relative_path = file_path.relative_to(self.template_dir)
  144. # Skip the main template file
  145. if filename in ["template.yaml", "template.yml"]:
  146. continue
  147. if filename.endswith(".j2"):
  148. file_type: Literal['j2', 'static'] = 'j2'
  149. output_path = relative_path.with_suffix('') # Remove .j2 suffix
  150. else:
  151. file_type = 'static'
  152. output_path = relative_path # Static files keep their name
  153. template_files.append(TemplateFile(relative_path=relative_path, file_type=file_type, output_path=output_path))
  154. self.__template_files = template_files
  155. def _extract_all_used_variables(self) -> Set[str]:
  156. """Extract all undeclared variables from all .j2 files in the template directory."""
  157. used_variables: Set[str] = set()
  158. for template_file in self.template_files: # Iterate over TemplateFile objects
  159. if template_file.file_type == 'j2':
  160. file_path = self.template_dir / template_file.relative_path
  161. try:
  162. with open(file_path, "r", encoding="utf-8") as f:
  163. content = f.read()
  164. ast = self.jinja_env.parse(content) # Use lazy-loaded jinja_env
  165. used_variables.update(meta.find_undeclared_variables(ast))
  166. except Exception as e:
  167. logger.warning(f"Could not parse Jinja2 variables from {file_path}: {e}")
  168. return used_variables
  169. def _filter_specs_to_used(self, used_variables: set, merged_specs: dict, module_specs: dict, template_specs: dict) -> dict:
  170. """Filter specs to only include variables used in the templates."""
  171. filtered_specs = {}
  172. for section_key, section_data in merged_specs.items():
  173. if "vars" in section_data and isinstance(section_data["vars"], dict):
  174. filtered_vars = {}
  175. for var_name, var_data in section_data["vars"].items():
  176. if var_name in used_variables:
  177. module_has_var = var_name in module_specs.get(section_key, {}).get("vars", {})
  178. template_has_var = var_name in template_specs.get(section_key, {}).get("vars", {})
  179. if module_has_var and template_has_var:
  180. origin = "module -> template"
  181. elif template_has_var:
  182. origin = "template"
  183. else:
  184. origin = "module"
  185. var_data_with_origin = {**var_data, "origin": origin}
  186. filtered_vars[var_name] = var_data_with_origin
  187. if filtered_vars:
  188. filtered_specs[section_key] = {**section_data, "vars": filtered_vars}
  189. return filtered_specs
  190. # ---------------------------
  191. # SECTION: Validation Methods
  192. # ---------------------------
  193. @staticmethod
  194. def _validate_kind(post: frontmatter.Post) -> None:
  195. """Validate that template has required 'kind' field."""
  196. if not post.metadata.get("kind"):
  197. raise ValueError("Template format error: missing 'kind' field")
  198. def _validate_variable_definitions(self, used_variables: set[str], merged_specs: dict[str, Any]) -> None:
  199. """Validate that all variables used in Jinja2 content are defined in the spec."""
  200. defined_variables = set()
  201. for section_data in merged_specs.values():
  202. if "vars" in section_data and isinstance(section_data["vars"], dict):
  203. defined_variables.update(section_data["vars"].keys())
  204. undefined_variables = used_variables - defined_variables
  205. if undefined_variables:
  206. undefined_list = sorted(undefined_variables)
  207. error_msg = (
  208. f"Template validation error in '{self.id}': "
  209. f"Variables used in template content but not defined in spec: {undefined_list}\n\n"
  210. f"Please add these variables to your template's template.yaml spec. "
  211. f"Each variable must have a default value.\n\n"
  212. f"Example:\n"
  213. f"spec:\n"
  214. f" general:\n"
  215. f" vars:\n"
  216. )
  217. for var_name in undefined_list:
  218. error_msg += (
  219. f" {var_name}:\n"
  220. f" type: str\n"
  221. f" description: Description for {var_name}\n"
  222. f" default: <your_default_value_here>\n"
  223. )
  224. logger.error(error_msg)
  225. raise ValueError(error_msg)
  226. # !SECTION
  227. # ---------------------------------
  228. # SECTION: Jinja2 Rendering Methods
  229. # ---------------------------------
  230. @staticmethod
  231. def _create_jinja_env(searchpath: Path) -> Environment:
  232. """Create standardized Jinja2 environment for consistent template processing."""
  233. return Environment(
  234. loader=FileSystemLoader(searchpath),
  235. trim_blocks=True,
  236. lstrip_blocks=True,
  237. keep_trailing_newline=False,
  238. )
  239. def render(self, variables: dict[str, Any]) -> Dict[str, str]:
  240. """Render all .j2 files in the template directory."""
  241. logger.debug(f"Rendering template '{self.id}' with variables: {variables}")
  242. rendered_files = {}
  243. for template_file in self.template_files: # Iterate over TemplateFile objects
  244. if template_file.file_type == 'j2':
  245. try:
  246. template = self.jinja_env.get_template(str(template_file.relative_path)) # Use lazy-loaded jinja_env
  247. rendered_content = template.render(**variables)
  248. rendered_files[str(template_file.output_path)] = rendered_content
  249. except Exception as e:
  250. logger.error(f"Error rendering template file {template_file.relative_path}: {e}")
  251. raise
  252. elif template_file.file_type == 'static':
  253. # For static files, just read their content and add to rendered_files
  254. # This ensures static files are also part of the output dictionary
  255. file_path = self.template_dir / template_file.relative_path
  256. try:
  257. with open(file_path, "r", encoding="utf-8") as f:
  258. content = f.read()
  259. rendered_files[str(template_file.output_path)] = content
  260. except Exception as e:
  261. logger.error(f"Error reading static file {file_path}: {e}")
  262. raise
  263. return rendered_files
  264. # !SECTION
  265. # ---------------------------
  266. # SECTION: Lazy Loaded Properties
  267. # ---------------------------
  268. @property
  269. def template_files(self) -> List[TemplateFile]:
  270. if self.__template_files is None:
  271. self._collect_template_files() # Populate self.__template_files
  272. return self.__template_files
  273. @property
  274. def template_specs(self) -> dict:
  275. return self._post.metadata.get("spec", {})
  276. @property
  277. def module_specs(self) -> dict:
  278. if self.__module_specs is None:
  279. kind = self._post.metadata.get("kind")
  280. self.__module_specs = self._load_module_specs(kind)
  281. return self.__module_specs
  282. @property
  283. def merged_specs(self) -> dict:
  284. if self.__merged_specs is None:
  285. self.__merged_specs = self._merge_specs(self.module_specs, self.template_specs)
  286. return self.__merged_specs
  287. @property
  288. def jinja_env(self) -> Environment:
  289. if self.__jinja_env is None:
  290. self.__jinja_env = self._create_jinja_env(self.template_dir)
  291. return self.__jinja_env
  292. @property
  293. def used_variables(self) -> Set[str]:
  294. if self.__used_variables is None:
  295. self.__used_variables = self._extract_all_used_variables()
  296. return self.__used_variables
  297. @property
  298. def variables(self) -> VariableCollection:
  299. if self.__variables is None:
  300. # Validate that all used variables are defined
  301. self._validate_variable_definitions(self.used_variables, self.merged_specs)
  302. # Filter specs to only used variables
  303. filtered_specs = self._filter_specs_to_used(self.used_variables, self.merged_specs, self.module_specs, self.template_specs)
  304. self.__variables = VariableCollection(filtered_specs)
  305. return self.__variables