variables.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. from __future__ import annotations
  2. from collections import OrderedDict
  3. from dataclasses import dataclass, field
  4. from typing import Any, Dict, List, Optional, Set
  5. from urllib.parse import urlparse
  6. import logging
  7. import re
  8. logger = logging.getLogger(__name__)
  9. # -----------------------
  10. # SECTION: Constants
  11. # -----------------------
  12. TRUE_VALUES = {"true", "1", "yes", "on"}
  13. FALSE_VALUES = {"false", "0", "no", "off"}
  14. HOSTNAME_REGEX = re.compile(r"^(?=.{1,253}$)(?!-)[A-Za-z0-9_-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9_-]{1,63}(?<!-))*$")
  15. EMAIL_REGEX = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
  16. # !SECTION
  17. # ----------------------
  18. # SECTION: Variable Class
  19. # ----------------------
  20. class Variable:
  21. """Represents a single templating variable with lightweight validation."""
  22. def __init__(self, data: dict[str, Any]) -> None:
  23. """Initialize Variable from a dictionary containing variable specification.
  24. Args:
  25. data: Dictionary containing variable specification with required 'name' key
  26. and optional keys: description, type, options, prompt, value, default, section, origin
  27. Raises:
  28. ValueError: If data is not a dict, missing 'name' key, or has invalid default value
  29. """
  30. # Validate input
  31. if not isinstance(data, dict):
  32. raise ValueError("Variable data must be a dictionary")
  33. if "name" not in data:
  34. raise ValueError("Variable data must contain 'name' key")
  35. # Initialize fields
  36. self.name: str = data["name"]
  37. self.description: Optional[str] = data.get("description") or data.get("display", "")
  38. self.type: str = data.get("type", "str")
  39. self.options: Optional[List[Any]] = data.get("options", [])
  40. self.prompt: Optional[str] = data.get("prompt")
  41. self.value: Any = data.get("value") if data.get("value") is not None else data.get("default")
  42. self.section: Optional[str] = data.get("section")
  43. self.origin: Optional[str] = data.get("origin")
  44. self.sensitive: bool = data.get("sensitive", False)
  45. # Validate and convert the default/initial value if present
  46. if self.value is not None:
  47. try:
  48. self.value = self.convert(self.value)
  49. except ValueError as exc:
  50. raise ValueError(f"Invalid default for variable '{self.name}': {exc}")
  51. # -------------------------
  52. # SECTION: Validation Helpers
  53. # -------------------------
  54. def _validate_not_empty(self, value: Any, converted_value: Any) -> None:
  55. """Validate that a value is not empty for non-boolean types."""
  56. if self.type not in ["bool"] and (converted_value is None or converted_value == ""):
  57. raise ValueError("value cannot be empty")
  58. def _validate_enum_option(self, value: str) -> None:
  59. """Validate that a value is in the allowed enum options."""
  60. if self.options and value not in self.options:
  61. raise ValueError(f"value must be one of: {', '.join(self.options)}")
  62. def _validate_regex_pattern(self, value: str, pattern: re.Pattern, error_msg: str) -> None:
  63. """Validate that a value matches a regex pattern."""
  64. if not pattern.fullmatch(value):
  65. raise ValueError(error_msg)
  66. def _validate_url_structure(self, parsed_url) -> None:
  67. """Validate that a parsed URL has required components."""
  68. if not (parsed_url.scheme and parsed_url.netloc):
  69. raise ValueError("value must be a valid URL (include scheme and host)")
  70. # !SECTION
  71. # -------------------------
  72. # SECTION: Type Conversion
  73. # -------------------------
  74. def convert(self, value: Any) -> Any:
  75. """Validate and convert a raw value based on the variable type."""
  76. if value is None:
  77. return None
  78. # Type conversion mapping for cleaner code
  79. converters = {
  80. "bool": self._convert_bool,
  81. "int": self._convert_int,
  82. "float": self._convert_float,
  83. "enum": self._convert_enum,
  84. "hostname": self._convert_hostname,
  85. "url": self._convert_url,
  86. "email": self._convert_email,
  87. }
  88. converter = converters.get(self.type)
  89. if converter:
  90. return converter(value)
  91. # Default to string conversion
  92. return str(value)
  93. def _convert_bool(self, value: Any) -> bool:
  94. """Convert value to boolean."""
  95. if isinstance(value, bool):
  96. return value
  97. if isinstance(value, str):
  98. lowered = value.strip().lower()
  99. if lowered in TRUE_VALUES:
  100. return True
  101. if lowered in FALSE_VALUES:
  102. return False
  103. raise ValueError("value must be a boolean (true/false)")
  104. def _convert_int(self, value: Any) -> Optional[int]:
  105. """Convert value to integer."""
  106. if isinstance(value, int):
  107. return value
  108. if isinstance(value, str) and value.strip() == "":
  109. return None
  110. try:
  111. return int(value)
  112. except (TypeError, ValueError) as exc:
  113. raise ValueError("value must be an integer") from exc
  114. def _convert_float(self, value: Any) -> Optional[float]:
  115. """Convert value to float."""
  116. if isinstance(value, float):
  117. return value
  118. if isinstance(value, str) and value.strip() == "":
  119. return None
  120. try:
  121. return float(value)
  122. except (TypeError, ValueError) as exc:
  123. raise ValueError("value must be a float") from exc
  124. def _convert_enum(self, value: Any) -> Optional[str]:
  125. """Convert value to enum option."""
  126. if value == "":
  127. return None
  128. val = str(value)
  129. self._validate_enum_option(val)
  130. return val
  131. def _convert_hostname(self, value: Any) -> str:
  132. """Convert and validate hostname."""
  133. val = str(value).strip()
  134. if not val:
  135. return ""
  136. if val.lower() != "localhost":
  137. self._validate_regex_pattern(val, HOSTNAME_REGEX, "value must be a valid hostname")
  138. return val
  139. def _convert_url(self, value: Any) -> str:
  140. """Convert and validate URL."""
  141. val = str(value).strip()
  142. if not val:
  143. return ""
  144. parsed = urlparse(val)
  145. self._validate_url_structure(parsed)
  146. return val
  147. def _convert_email(self, value: Any) -> str:
  148. """Convert and validate email."""
  149. val = str(value).strip()
  150. if not val:
  151. return ""
  152. self._validate_regex_pattern(val, EMAIL_REGEX, "value must be a valid email address")
  153. return val
  154. def get_typed_value(self) -> Any:
  155. """Return the stored value converted to the appropriate Python type."""
  156. return self.convert(self.value)
  157. # !SECTION
  158. # !SECTION
  159. # ----------------------------
  160. # SECTION: VariableSection Class
  161. # ----------------------------
  162. class VariableSection:
  163. """Groups variables together with shared metadata for presentation."""
  164. def __init__(self, data: dict[str, Any]) -> None:
  165. """Initialize VariableSection from a dictionary.
  166. Args:
  167. data: Dictionary containing section specification with required 'key' and 'title' keys
  168. """
  169. if not isinstance(data, dict):
  170. raise ValueError("VariableSection data must be a dictionary")
  171. if "key" not in data:
  172. raise ValueError("VariableSection data must contain 'key'")
  173. if "title" not in data:
  174. raise ValueError("VariableSection data must contain 'title'")
  175. self.key: str = data["key"]
  176. self.title: str = data["title"]
  177. self.variables: OrderedDict[str, Variable] = OrderedDict()
  178. self.prompt: Optional[str] = data.get("prompt")
  179. self.description: Optional[str] = data.get("description")
  180. self.toggle: Optional[str] = data.get("toggle")
  181. # Default "general" section to required=True, all others to required=False
  182. self.required: bool = data.get("required", data["key"] == "general")
  183. def variable_names(self) -> list[str]:
  184. return list(self.variables.keys())
  185. # !SECTION
  186. # --------------------------------
  187. # SECTION: VariableCollection Class
  188. # --------------------------------
  189. class VariableCollection:
  190. """Manages variables grouped by sections and builds Jinja context."""
  191. def __init__(self, spec: dict[str, Any]) -> None:
  192. """Initialize VariableCollection from a specification dictionary.
  193. Args:
  194. spec: Dictionary containing the complete variable specification structure
  195. Expected format (as used in compose.py):
  196. {
  197. "section_key": {
  198. "title": "Section Title",
  199. "prompt": "Optional prompt text",
  200. "toggle": "optional_toggle_var_name",
  201. "description": "Optional description",
  202. "vars": {
  203. "var_name": {
  204. "description": "Variable description",
  205. "type": "str",
  206. "default": "default_value",
  207. ...
  208. }
  209. }
  210. }
  211. }
  212. """
  213. if not isinstance(spec, dict):
  214. raise ValueError("Spec must be a dictionary")
  215. self._sections: Dict[str, VariableSection] = {}
  216. # NOTE: The _variable_map provides a flat, O(1) lookup for any variable by its name,
  217. # avoiding the need to iterate through sections. It stores references to the same
  218. # Variable objects contained in the _set structure.
  219. self._variable_map: Dict[str, Variable] = {}
  220. self._initialize_sections(spec)
  221. def _initialize_sections(self, spec: dict[str, Any]) -> None:
  222. """Initialize sections from the spec."""
  223. for section_key, section_data in spec.items():
  224. if not isinstance(section_data, dict):
  225. continue
  226. section = self._create_section(section_key, section_data)
  227. self._initialize_variables(section, section_data.get("vars", {}))
  228. self._sections[section_key] = section
  229. def _create_section(self, key: str, data: dict[str, Any]) -> VariableSection:
  230. """Create a VariableSection from data."""
  231. section_init_data = {
  232. "key": key,
  233. "title": data.get("title", key.replace("_", " ").title()),
  234. "prompt": data.get("prompt"),
  235. "description": data.get("description"),
  236. "toggle": data.get("toggle"),
  237. "required": data.get("required", key == "general")
  238. }
  239. return VariableSection(section_init_data)
  240. def _initialize_variables(self, section: VariableSection, vars_data: dict[str, Any]) -> None:
  241. """Initialize variables for a section."""
  242. for var_name, var_data in vars_data.items():
  243. var_init_data = {"name": var_name, **var_data}
  244. variable = Variable(var_init_data)
  245. section.variables[var_name] = variable
  246. # NOTE: Populate the direct lookup map for efficient access.
  247. self._variable_map[var_name] = variable
  248. # -------------------------
  249. # SECTION: Public API Methods
  250. # -------------------------
  251. def get_sections(self) -> Dict[str, VariableSection]:
  252. """Get all sections in the collection."""
  253. return self._sections.copy()
  254. def get_section(self, key: str) -> Optional[VariableSection]:
  255. """Get a specific section by its key."""
  256. return self._sections.get(key)
  257. def has_sections(self) -> bool:
  258. """Check if the collection has any sections."""
  259. return bool(self._sections)
  260. def get_all_values(self) -> dict[str, Any]:
  261. """Get all variable values as a dictionary."""
  262. # NOTE: This method is optimized to use the _variable_map for direct O(1) access
  263. # to each variable, which is much faster than iterating through sections.
  264. all_values = {}
  265. for var_name, variable in self._variable_map.items():
  266. all_values[var_name] = variable.get_typed_value()
  267. return all_values
  268. def get_sensitive_variables(self) -> Dict[str, Any]:
  269. """Get only the sensitive variables with their values."""
  270. return {name: var.value for name, var in self._variable_map.items() if var.sensitive and var.value}
  271. # !SECTION
  272. # -------------------------
  273. # SECTION: Helper Methods
  274. # -------------------------
  275. # NOTE: These helper methods reduce code duplication across module.py and prompt.py
  276. # by centralizing common variable collection operations
  277. def apply_overrides(self, overrides: dict[str, Any], origin_suffix: str = " -> cli") -> list[str]:
  278. """Apply multiple variable overrides at once."""
  279. # NOTE: This method uses the _variable_map for a significant performance gain,
  280. # as it allows direct O(1) lookup of variables instead of iterating
  281. # through all sections to find a match.
  282. successful_overrides = []
  283. errors = []
  284. for var_name, value in overrides.items():
  285. try:
  286. variable = self._variable_map.get(var_name)
  287. if not variable:
  288. logger.warning(f"Variable '{var_name}' not found in template")
  289. continue
  290. # Convert and set the new value
  291. converted_value = variable.convert(value)
  292. variable.value = converted_value
  293. # Update origin to show override
  294. if variable.origin:
  295. variable.origin = variable.origin + origin_suffix
  296. else:
  297. variable.origin = origin_suffix.lstrip(" -> ")
  298. successful_overrides.append(var_name)
  299. except ValueError as e:
  300. error_msg = f"Invalid override value for '{var_name}': {value} - {e}"
  301. errors.append(error_msg)
  302. logger.error(error_msg)
  303. if errors:
  304. logger.warning(f"Some CLI overrides failed: {'; '.join(errors)}")
  305. def validate_all(self) -> None:
  306. """Validate all variables in the collection, skipping disabled sections."""
  307. for section in self._sections.values():
  308. # Check if the section is disabled by a toggle
  309. if section.toggle:
  310. toggle_var = section.variables.get(section.toggle)
  311. if toggle_var and not toggle_var.get_typed_value():
  312. logger.debug(f"Skipping validation for disabled section: '{section.key}'")
  313. continue # Skip this entire section
  314. # NOTE: Skip individual variable validation since we removed the validate method
  315. # All validation now happens during conversion in the Variable.convert() method
  316. pass
  317. # !SECTION
  318. # !SECTION