variables.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. from __future__ import annotations
  2. from collections import OrderedDict
  3. from dataclasses import dataclass, field
  4. from typing import Any, Dict, List, Optional, Set
  5. from urllib.parse import urlparse
  6. import logging
  7. import re
  logger = logging.getLogger(__name__)

  # -----------------------
  # SECTION: Constants
  # -----------------------
  # Spellings accepted for boolean variables. Variable._convert_bool compares
  # against these after strip() + lower(), so they must stay lowercase.
  TRUE_VALUES = set("true 1 yes on".split())
  FALSE_VALUES = set("false 0 no off".split())

  # Hostname check built from three pieces (concatenated into a single pattern):
  HOSTNAME_REGEX = re.compile(
      r"^(?=.{1,253}$)"  # limits the whole name to 1-253 characters
      r"(?!-)[A-Za-z0-9_-]{1,63}(?<!-)"  # first label: 1-63 chars, no leading/trailing hyphen
      r"(\.(?!-)[A-Za-z0-9_-]{1,63}(?<!-))*$"  # zero or more further dot-separated labels
  )

  # Deliberately loose e-mail check: something@something.something with no
  # whitespace and exactly one "@".
  EMAIL_REGEX = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
  # !SECTION
  # ----------------------
  # SECTION: Variable Class
  # ----------------------
  class Variable:
      """Represents a single templating variable with lightweight validation.

      The initial value (``value`` when it is not None, otherwise ``default``) is
      validated and converted according to ``type`` at construction time.
      """

      # Maps a variable ``type`` to the name of the method that validates and
      # converts values of that type. Built once here rather than on every
      # ``convert`` call; resolved via getattr so subclass overrides still apply.
      # Types not listed here fall back to ``str(value)``.
      _CONVERTERS: Dict[str, str] = {
          "bool": "_convert_bool",
          "int": "_convert_int",
          "float": "_convert_float",
          "enum": "_convert_enum",
          "hostname": "_convert_hostname",
          "url": "_convert_url",
          "email": "_convert_email",
      }

      def __init__(self, data: dict[str, Any]) -> None:
          """Initialize Variable from a dictionary containing variable specification.

          Args:
              data: Dictionary containing variable specification with required 'name' key
                  and optional keys: description (or display), type, options, prompt,
                  value, default, section, origin.

          Raises:
              ValueError: If data is not a dict, is missing the 'name' key, or has an
                  invalid initial value (the conversion error is chained as __cause__).
          """
          # Validate input
          if not isinstance(data, dict):
              raise ValueError("Variable data must be a dictionary")
          if "name" not in data:
              raise ValueError("Variable data must contain 'name' key")
          # Initialize fields
          self.name: str = data["name"]
          # "display" is used only when "description" is missing or falsy.
          self.description: Optional[str] = data.get("description") or data.get("display", "")
          self.type: str = data.get("type", "str")
          self.options: Optional[List[Any]] = data.get("options", [])
          self.prompt: Optional[str] = data.get("prompt")
          # A non-None "value" takes precedence over "default".
          self.value: Any = data.get("value") if data.get("value") is not None else data.get("default")
          self.section: Optional[str] = data.get("section")
          self.origin: Optional[str] = data.get("origin")
          # Validate and convert the initial value if present
          if self.value is not None:
              try:
                  self.value = self.convert(self.value)
              except ValueError as exc:
                  raise ValueError(f"Invalid default for variable '{self.name}': {exc}") from exc

      # -------------------------
      # SECTION: Type Conversion
      # -------------------------
      def convert(self, value: Any) -> Any:
          """Validate and convert a raw value based on the variable type.

          Returns None for None. Types without a dedicated converter are
          converted with ``str(value)``.

          Raises:
              ValueError: If the value is not valid for the variable's type.
          """
          if value is None:
              return None
          method_name = self._CONVERTERS.get(self.type)
          if method_name is not None:
              return getattr(self, method_name)(value)
          # Default to string conversion
          return str(value)

      def _convert_bool(self, value: Any) -> bool:
          """Convert value to boolean.

          Accepts real booleans, the integers 0 and 1, and strings found in
          TRUE_VALUES / FALSE_VALUES after strip() and lower().

          Raises:
              ValueError: If value cannot be interpreted as a boolean.
          """
          if isinstance(value, bool):
              return value
          # Integers 0/1 mirror the "0"/"1" strings that are already accepted.
          if isinstance(value, int) and value in (0, 1):
              return bool(value)
          if isinstance(value, str):
              lowered = value.strip().lower()
              if lowered in TRUE_VALUES:
                  return True
              if lowered in FALSE_VALUES:
                  return False
          raise ValueError("value must be a boolean (true/false)")

      def _convert_int(self, value: Any) -> Optional[int]:
          """Convert value to integer.

          Returns None for a blank string. Ints (including bools, which are an
          int subclass) are returned unchanged. Floats are accepted only when they
          have no fractional part, so values are never silently truncated.

          Raises:
              ValueError: If value is not an integer or integer-like string.
          """
          if isinstance(value, int):
              return value
          if isinstance(value, float):
              # is_integer() is False for fractional values, NaN and +/-inf.
              if not value.is_integer():
                  raise ValueError("value must be an integer")
              return int(value)
          if isinstance(value, str) and value.strip() == "":
              return None
          try:
              return int(value)
          except (TypeError, ValueError, OverflowError) as exc:
              raise ValueError("value must be an integer") from exc

      def _convert_float(self, value: Any) -> Optional[float]:
          """Convert value to float.

          Returns None for a blank string.

          Raises:
              ValueError: If value cannot be converted to a float (including
                  integers too large to represent).
          """
          if isinstance(value, float):
              return value
          if isinstance(value, str) and value.strip() == "":
              return None
          try:
              return float(value)
          except (TypeError, ValueError, OverflowError) as exc:
              raise ValueError("value must be a float") from exc

      def _convert_enum(self, value: Any) -> Optional[str]:
          """Convert value to one of the allowed enum options.

          Options are compared as strings, so non-string options (e.g. ints) match
          their string form and can be listed in the error message. Returns None
          for an empty string; any value is accepted when options is empty.

          Raises:
              ValueError: If options are set and value is not one of them.
          """
          if value == "":
              return None
          val = str(value)
          if self.options:
              allowed = [str(option) for option in self.options]
              if val not in allowed:
                  raise ValueError(f"value must be one of: {', '.join(allowed)}")
          return val

      def _convert_hostname(self, value: Any) -> str:
          """Convert and validate a hostname ("localhost" is always accepted).

          Returns "" for blank input.

          Raises:
              ValueError: If the value does not match HOSTNAME_REGEX.
          """
          val = str(value).strip()
          if not val:
              return ""
          if val.lower() == "localhost":
              return val
          if not HOSTNAME_REGEX.fullmatch(val):
              raise ValueError("value must be a valid hostname")
          return val

      def _convert_url(self, value: Any) -> str:
          """Convert and validate a URL, requiring both a scheme and a host.

          Returns "" for blank input.

          Raises:
              ValueError: If urlparse finds no scheme or no netloc.
          """
          val = str(value).strip()
          if not val:
              return ""
          parsed = urlparse(val)
          if not (parsed.scheme and parsed.netloc):
              raise ValueError("value must be a valid URL (include scheme and host)")
          return val

      def _convert_email(self, value: Any) -> str:
          """Convert and validate an e-mail address.

          Returns "" for blank input.

          Raises:
              ValueError: If the value does not match EMAIL_REGEX.
          """
          val = str(value).strip()
          if not val:
              return ""
          if not EMAIL_REGEX.fullmatch(val):
              raise ValueError("value must be a valid email address")
          return val

      def get_typed_value(self) -> Any:
          """Return the stored value converted to the appropriate Python type.

          Re-runs convert() on the stored value, which was already converted at
          construction or when it was last assigned by a caller.
          """
          return self.convert(self.value)
      # !SECTION
  # !SECTION
  # ----------------------------
  # SECTION: VariableSection Class
  # ----------------------------
  class VariableSection:
      """Groups variables together with shared metadata for presentation."""

      def __init__(self, data: dict[str, Any]) -> None:
          """Build a section from its specification dictionary.

          Args:
              data: Section specification. Must contain "key" and "title";
                  "prompt", "description", "toggle" and "required" are optional.

          Raises:
              ValueError: If data is not a dict or lacks "key" or "title".
          """
          if not isinstance(data, dict):
              raise ValueError("VariableSection data must be a dictionary")
          if "key" not in data:
              raise ValueError("VariableSection data must contain 'key'")
          if "title" not in data:
              raise ValueError("VariableSection data must contain 'title'")

          section_key: str = data["key"]
          self.key: str = section_key
          self.title: str = data["title"]
          # Starts empty; entries are assigned after construction and keep
          # their insertion order.
          self.variables: OrderedDict[str, Variable] = OrderedDict()
          self.prompt: Optional[str] = data.get("prompt")
          self.description: Optional[str] = data.get("description")
          self.toggle: Optional[str] = data.get("toggle")
          # An explicit "required" entry (even None) wins; otherwise only the
          # "general" section is required.
          if "required" in data:
              self.required: bool = data["required"]
          else:
              self.required = section_key == "general"

      def variable_names(self) -> list[str]:
          """Return the names of this section's variables in insertion order."""
          return [*self.variables]
  # !SECTION
  # --------------------------------
  # SECTION: VariableCollection Class
  # --------------------------------
  class VariableCollection:
      """Manages variables grouped by sections and builds Jinja context."""

      def __init__(self, spec: dict[str, Any]) -> None:
          """Initialize VariableCollection from a specification dictionary.

          Args:
              spec: Dictionary containing the complete variable specification structure.
                  Expected format (as used in compose.py):
                  {
                      "section_key": {
                          "title": "Section Title",
                          "prompt": "Optional prompt text",
                          "toggle": "optional_toggle_var_name",
                          "description": "Optional description",
                          "vars": {
                              "var_name": {
                                  "description": "Variable description",
                                  "type": "str",
                                  "default": "default_value",
                                  ...
                              }
                          }
                      }
                  }

          Raises:
              ValueError: If spec is not a dict, a section's "vars" is not a dict,
                  a variable entry is not a dict, or a variable's initial value is
                  invalid.
          """
          if not isinstance(spec, dict):
              raise ValueError("Spec must be a dictionary")
          self._set: Dict[str, VariableSection] = {}
          # Initialize sections and their variables
          for section_key, section_data in spec.items():
              # Entries whose value is not a dict are ignored rather than treated as sections.
              if not isinstance(section_data, dict):
                  continue
              section = VariableSection({
                  "key": section_key,
                  "title": section_data.get("title", section_key.replace("_", " ").title()),
                  "prompt": section_data.get("prompt"),
                  "description": section_data.get("description"),
                  "toggle": section_data.get("toggle"),
                  "required": section_data.get("required", section_key == "general"),
              })
              # A "vars" value of None (or any other falsy value) means no variables.
              vars_data = section_data.get("vars") or {}
              if not isinstance(vars_data, dict):
                  raise ValueError(f"Section '{section_key}': 'vars' must be a dictionary")
              for var_name, var_data in vars_data.items():
                  if var_data is None:
                      var_data = {}
                  if not isinstance(var_data, dict):
                      raise ValueError(
                          f"Variable '{var_name}' in section '{section_key}' must be a dictionary"
                      )
                  # The mapping key is authoritative: it overrides any "name" inside
                  # var_data so section.variables[key].name always equals key.
                  section.variables[var_name] = Variable({**var_data, "name": var_name})
              self._set[section_key] = section

      # -------------------------
      # SECTION: Helper Methods
      # -------------------------
      # NOTE: These helpers centralize variable-collection operations that were
      # previously duplicated in module.py and prompt.py.
      def get_all_values(self) -> dict[str, Any]:
          """Get all variable values as a dictionary.

          Returns:
              Dictionary mapping variable names to their typed values. If the same
              name appears in several sections, the last section wins.
              NOTE(review): apply_overrides updates the *first* match instead.
          """
          return {
              var_name: variable.get_typed_value()
              for section in self._set.values()
              for var_name, variable in section.variables.items()
          }

      def _find_variable(self, var_name: str) -> Optional[Variable]:
          """Return the first variable named var_name, searching sections in order, or None."""
          for section in self._set.values():
              if var_name in section.variables:
                  return section.variables[var_name]
          return None

      @staticmethod
      def _origin_from_suffix(origin_suffix: str) -> str:
          """Turn an origin suffix like " -> cli" into a bare origin label ("cli").

          Strips surrounding whitespace and one leading "->" marker. Unlike
          str.lstrip(" -> "), this does not eat hyphens belonging to the label.
          """
          label = origin_suffix.strip()
          if label.startswith("->"):
              label = label[len("->"):].strip()
          return label

      def apply_overrides(self, overrides: dict[str, Any], origin_suffix: str = " -> cli") -> list[str]:
          """Apply multiple variable overrides at once.

          Unknown names and invalid values are logged and skipped; they do not
          stop the remaining overrides from being applied.

          Args:
              overrides: Dictionary of variable names to values.
              origin_suffix: Suffix appended to the origin of each overridden
                  variable. If the variable has no origin yet, the suffix without
                  its leading "->" marker becomes the origin.

          Returns:
              List of variable names that were successfully overridden.
          """
          successful_overrides: list[str] = []
          errors: list[str] = []
          for var_name, value in overrides.items():
              variable = self._find_variable(var_name)
              if variable is None:
                  logger.warning("Variable '%s' not found in template", var_name)
                  continue
              try:
                  converted_value = variable.convert(value)
              except ValueError as e:
                  error_msg = f"Invalid override value for '{var_name}': {value} - {e}"
                  errors.append(error_msg)
                  logger.error("%s", error_msg)
                  continue
              variable.value = converted_value
              # Update origin to show override
              if variable.origin:
                  variable.origin = variable.origin + origin_suffix
              else:
                  variable.origin = self._origin_from_suffix(origin_suffix)
              successful_overrides.append(var_name)
          if errors:
              # Log errors but don't stop the process
              logger.warning("Some CLI overrides failed: %s", "; ".join(errors))
          return successful_overrides
      # !SECTION
  # !SECTION