variables.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. from __future__ import annotations
  2. from collections import OrderedDict
  3. from dataclasses import dataclass, field
  4. from typing import Any, Dict, List, Optional, Set
  5. from urllib.parse import urlparse
  6. import logging
  7. import re
  8. logger = logging.getLogger(__name__)
  9. # -----------------------
  10. # SECTION: Constants
  11. # -----------------------
  12. TRUE_VALUES = {"true", "1", "yes", "on"}
  13. FALSE_VALUES = {"false", "0", "no", "off"}
  14. HOSTNAME_REGEX = re.compile(r"^(?=.{1,253}$)(?!-)[A-Za-z0-9_-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9_-]{1,63}(?<!-))*$")
  15. EMAIL_REGEX = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
  16. # !SECTION
  17. # ----------------------
  18. # SECTION: Variable Class
  19. # ----------------------
  20. class Variable:
  21. """Represents a single templating variable with lightweight validation."""
  22. def __init__(self, data: dict[str, Any]) -> None:
  23. """Initialize Variable from a dictionary containing variable specification.
  24. Args:
  25. data: Dictionary containing variable specification with required 'name' key
  26. and optional keys: description, type, options, prompt, value, default, section, origin
  27. Raises:
  28. ValueError: If data is not a dict, missing 'name' key, or has invalid default value
  29. """
  30. # Validate input
  31. if not isinstance(data, dict):
  32. raise ValueError("Variable data must be a dictionary")
  33. if "name" not in data:
  34. raise ValueError("Variable data must contain 'name' key")
  35. # Initialize fields
  36. self.name: str = data["name"]
  37. self.description: Optional[str] = data.get("description") or data.get("display", "")
  38. self.type: str = data.get("type", "str")
  39. self.options: Optional[List[Any]] = data.get("options", [])
  40. self.prompt: Optional[str] = data.get("prompt")
  41. self.value: Any = data.get("value") if data.get("value") is not None else data.get("default")
  42. self.section: Optional[str] = data.get("section")
  43. self.origin: Optional[str] = data.get("origin")
  44. # Validate and convert the default/initial value if present
  45. if self.value is not None:
  46. try:
  47. self.value = self.convert(self.value)
  48. except ValueError as exc:
  49. raise ValueError(f"Invalid default for variable '{self.name}': {exc}")
  50. def validate(self, value: Any) -> None:
  51. """Validate a value based on the variable's type and constraints."""
  52. if self.type not in ["bool"] and (value is None or value == ""):
  53. raise ValueError("value cannot be empty")
  54. # -------------------------
  55. # SECTION: Type Conversion
  56. # -------------------------
  57. def convert(self, value: Any) -> Any:
  58. """Validate and convert a raw value based on the variable type."""
  59. if value is None:
  60. return None
  61. # Type conversion mapping for cleaner code
  62. converters = {
  63. "bool": self._convert_bool,
  64. "int": self._convert_int,
  65. "float": self._convert_float,
  66. "enum": self._convert_enum,
  67. "hostname": self._convert_hostname,
  68. "url": self._convert_url,
  69. "email": self._convert_email,
  70. }
  71. converter = converters.get(self.type)
  72. if converter:
  73. return converter(value)
  74. # Default to string conversion
  75. return str(value)
  76. def _convert_bool(self, value: Any) -> bool:
  77. """Convert value to boolean."""
  78. if isinstance(value, bool):
  79. return value
  80. if isinstance(value, str):
  81. lowered = value.strip().lower()
  82. if lowered in TRUE_VALUES:
  83. return True
  84. if lowered in FALSE_VALUES:
  85. return False
  86. raise ValueError("value must be a boolean (true/false)")
  87. def _convert_int(self, value: Any) -> Optional[int]:
  88. """Convert value to integer."""
  89. if isinstance(value, int):
  90. return value
  91. if isinstance(value, str) and value.strip() == "":
  92. return None
  93. try:
  94. return int(value)
  95. except (TypeError, ValueError) as exc:
  96. raise ValueError("value must be an integer") from exc
  97. def _convert_float(self, value: Any) -> Optional[float]:
  98. """Convert value to float."""
  99. if isinstance(value, float):
  100. return value
  101. if isinstance(value, str) and value.strip() == "":
  102. return None
  103. try:
  104. return float(value)
  105. except (TypeError, ValueError) as exc:
  106. raise ValueError("value must be a float") from exc
  107. def _convert_enum(self, value: Any) -> Optional[str]:
  108. """Convert value to enum option."""
  109. if value == "":
  110. return None
  111. val = str(value)
  112. if self.options and val not in self.options:
  113. raise ValueError(f"value must be one of: {', '.join(self.options)}")
  114. return val
  115. def _convert_hostname(self, value: Any) -> str:
  116. """Convert and validate hostname."""
  117. val = str(value).strip()
  118. if not val:
  119. return ""
  120. if val.lower() == "localhost":
  121. return val
  122. if not HOSTNAME_REGEX.fullmatch(val):
  123. raise ValueError("value must be a valid hostname")
  124. return val
  125. def _convert_url(self, value: Any) -> str:
  126. """Convert and validate URL."""
  127. val = str(value).strip()
  128. if not val:
  129. return ""
  130. parsed = urlparse(val)
  131. if not (parsed.scheme and parsed.netloc):
  132. raise ValueError("value must be a valid URL (include scheme and host)")
  133. return val
  134. def _convert_email(self, value: Any) -> str:
  135. """Convert and validate email."""
  136. val = str(value).strip()
  137. if not val:
  138. return ""
  139. if not EMAIL_REGEX.fullmatch(val):
  140. raise ValueError("value must be a valid email address")
  141. return val
  142. def get_typed_value(self) -> Any:
  143. """Return the stored value converted to the appropriate Python type."""
  144. return self.convert(self.value)
  145. # !SECTION
  146. # !SECTION
  147. # ----------------------------
  148. # SECTION: VariableSection Class
  149. # ----------------------------
  150. class VariableSection:
  151. """Groups variables together with shared metadata for presentation."""
  152. def __init__(self, data: dict[str, Any]) -> None:
  153. """Initialize VariableSection from a dictionary.
  154. Args:
  155. data: Dictionary containing section specification with required 'key' and 'title' keys
  156. """
  157. if not isinstance(data, dict):
  158. raise ValueError("VariableSection data must be a dictionary")
  159. if "key" not in data:
  160. raise ValueError("VariableSection data must contain 'key'")
  161. if "title" not in data:
  162. raise ValueError("VariableSection data must contain 'title'")
  163. self.key: str = data["key"]
  164. self.title: str = data["title"]
  165. self.variables: OrderedDict[str, Variable] = OrderedDict()
  166. self.prompt: Optional[str] = data.get("prompt")
  167. self.description: Optional[str] = data.get("description")
  168. self.toggle: Optional[str] = data.get("toggle")
  169. # Default "general" section to required=True, all others to required=False
  170. self.required: bool = data.get("required", data["key"] == "general")
  171. def variable_names(self) -> list[str]:
  172. return list(self.variables.keys())
  173. # !SECTION
  174. # --------------------------------
  175. # SECTION: VariableCollection Class
  176. # --------------------------------
  177. class VariableCollection:
  178. """Manages variables grouped by sections and builds Jinja context."""
  179. def __init__(self, spec: dict[str, Any]) -> None:
  180. """Initialize VariableCollection from a specification dictionary.
  181. Args:
  182. spec: Dictionary containing the complete variable specification structure
  183. Expected format (as used in compose.py):
  184. {
  185. "section_key": {
  186. "title": "Section Title",
  187. "prompt": "Optional prompt text",
  188. "toggle": "optional_toggle_var_name",
  189. "description": "Optional description",
  190. "vars": {
  191. "var_name": {
  192. "description": "Variable description",
  193. "type": "str",
  194. "default": "default_value",
  195. ...
  196. }
  197. }
  198. }
  199. }
  200. """
  201. if not isinstance(spec, dict):
  202. raise ValueError("Spec must be a dictionary")
  203. self._set: Dict[str, VariableSection] = {}
  204. # Initialize sections and their variables
  205. for section_key, section_data in spec.items():
  206. if not isinstance(section_data, dict):
  207. continue
  208. # Create section data with the key included
  209. section_init_data = {
  210. "key": section_key,
  211. "title": section_data.get("title", section_key.replace("_", " ").title()),
  212. "prompt": section_data.get("prompt"),
  213. "description": section_data.get("description"),
  214. "toggle": section_data.get("toggle"),
  215. "required": section_data.get("required", section_key == "general")
  216. }
  217. section = VariableSection(section_init_data)
  218. # Initialize variables in this section
  219. if "vars" in section_data:
  220. for var_name, var_data in section_data["vars"].items():
  221. # Add variable name to the data
  222. var_init_data = {"name": var_name, **var_data}
  223. variable = Variable(var_init_data)
  224. section.variables[var_name] = variable
  225. self._set[section_key] = section
  226. # -------------------------
  227. # SECTION: Helper Methods
  228. # -------------------------
  229. # NOTE: These helper methods reduce code duplication across module.py and prompt.py
  230. # by centralizing common variable collection operations
  231. def get_all_values(self) -> dict[str, Any]:
  232. """Get all variable values as a dictionary.
  233. Returns:
  234. Dictionary mapping variable names to their typed values
  235. """
  236. # NOTE: Eliminates the need to iterate through sections and variables manually
  237. # in module.py _extract_current_variable_values() method
  238. all_values = {}
  239. for section in self._set.values():
  240. for var_name, variable in section.variables.items():
  241. all_values[var_name] = variable.get_typed_value()
  242. return all_values
  243. def apply_overrides(self, overrides: dict[str, Any], origin_suffix: str = " -> cli") -> list[str]:
  244. """Apply multiple variable overrides at once.
  245. Args:
  246. overrides: Dictionary of variable names to values
  247. origin_suffix: Suffix to append to origins for overridden variables
  248. Returns:
  249. List of variable names that were successfully overridden
  250. """
  251. # NOTE: Replaces the complex _apply_cli_overrides() method in module.py
  252. # by centralizing override logic with proper error handling and origin tracking
  253. successful_overrides = []
  254. errors = []
  255. for var_name, value in overrides.items():
  256. try:
  257. # Find and update the variable
  258. found = False
  259. for section in self._set.values():
  260. if var_name in section.variables:
  261. variable = section.variables[var_name]
  262. # Convert and set the new value
  263. converted_value = variable.convert(value)
  264. variable.value = converted_value
  265. # Update origin to show override
  266. if variable.origin:
  267. variable.origin = variable.origin + origin_suffix
  268. else:
  269. variable.origin = origin_suffix.lstrip(" -> ")
  270. successful_overrides.append(var_name)
  271. found = True
  272. break
  273. if not found:
  274. logger.warning(f"Variable '{var_name}' not found in template")
  275. except ValueError as e:
  276. error_msg = f"Invalid override value for '{var_name}': {value} - {e}"
  277. errors.append(error_msg)
  278. logger.error(error_msg)
  279. if errors:
  280. # Log errors but don't stop the process
  281. logger.warning(f"Some CLI overrides failed: {'; '.join(errors)}")
  282. def validate_all(self) -> None:
  283. """Validate all variables in the collection, skipping disabled sections."""
  284. for section in self._set.values():
  285. # Check if the section is disabled by a toggle
  286. if section.toggle:
  287. toggle_var = section.variables.get(section.toggle)
  288. if toggle_var and not toggle_var.get_typed_value():
  289. logger.debug(f"Skipping validation for disabled section: '{section.key}'")
  290. continue # Skip this entire section
  291. for var_name, variable in section.variables.items():
  292. try:
  293. variable.validate(variable.value)
  294. except ValueError as e:
  295. raise ValueError(f"Validation failed for variable '{var_name}': {e}") from e
  296. # !SECTION
  297. # !SECTION
  298. # !SECTION