variables.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. from __future__ import annotations
  2. from collections import OrderedDict
  3. from dataclasses import dataclass, field
  4. from typing import Any, Dict, List, Optional, Set, Union
  5. from urllib.parse import urlparse
  6. import logging
  7. import re
  # Module-level logger.
logger = logging.getLogger(__name__)

# -----------------------
# SECTION: Constants
# -----------------------

# Spellings accepted for boolean variables; Variable._convert_bool strips and
# lowercases string input before looking it up in these sets.
TRUE_VALUES = {"true", "1", "yes", "on"}
FALSE_VALUES = {"false", "0", "no", "off"}

# Hostname check (applied with fullmatch): 1-253 characters overall, made of
# dot-separated labels of 1-63 characters drawn from letters, digits, "_" and
# "-", where no label may begin or end with "-".
HOSTNAME_REGEX = re.compile(
    r"^(?=.{1,253}$)(?!-)[A-Za-z0-9_-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9_-]{1,63}(?<!-))*$"
)

# Deliberately loose email check: exactly one "@", no whitespace, and at least
# one "." in the part after the "@".
EMAIL_REGEX = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")
# !SECTION
  # ----------------------
# SECTION: Variable Class
# ----------------------
class Variable:
    """Represents a single templating variable with lightweight validation.

    The initial value is validated and normalized on construction via
    ``convert()`` according to ``type`` (bool, int, float, enum, hostname, url,
    email; any other type is stored as ``str``).
    """

    def __init__(self, data: dict[str, Any]) -> None:
        """Initialize Variable from a dictionary containing variable specification.

        Args:
            data: Dictionary containing variable specification with required 'name' key
                and optional keys: description (falls back to 'display'), type, options,
                prompt, value, default, section, origin, sensitive, extra.
                A non-None 'value' takes precedence over 'default'.

        Raises:
            ValueError: If data is not a dict, missing 'name' key, or has invalid default value
        """
        # Validate input
        if not isinstance(data, dict):
            raise ValueError("Variable data must be a dictionary")
        if "name" not in data:
            raise ValueError("Variable data must contain 'name' key")

        # Initialize fields
        self.name: str = data["name"]
        self.description: Optional[str] = data.get("description") or data.get("display", "")
        # An omitted type means a plain string variable.
        self.type: str = data.get("type", "str")
        # Normalize a missing/None options entry (clone() passes None when empty)
        # to an empty list so the attribute is always a list.
        self.options: List[Any] = data.get("options") or []
        self.prompt: Optional[str] = data.get("prompt")
        explicit_value = data.get("value")
        self.value: Any = explicit_value if explicit_value is not None else data.get("default")
        self.section: Optional[str] = data.get("section")
        self.origin: Optional[str] = data.get("origin")
        self.sensitive: bool = data.get("sensitive", False)
        # Optional extra explanation used by interactive prompts
        self.extra: Optional[str] = data.get("extra")

        # Validate and convert the default/initial value if present
        if self.value is not None:
            try:
                self.value = self.convert(self.value)
            except ValueError as exc:
                raise ValueError(f"Invalid default for variable '{self.name}': {exc}") from exc

    # -------------------------
    # SECTION: Validation Helpers
    # -------------------------

    def _validate_not_empty(self, value: Any, converted_value: Any) -> None:
        """Validate that a value is not empty for non-boolean types."""
        if self.type not in ["bool"] and (converted_value is None or converted_value == ""):
            raise ValueError("value cannot be empty")

    def _allowed_options(self) -> List[str]:
        """Return the enum options as strings.

        Converted enum values are always strings, while options may arrive from
        YAML as other scalars (e.g. ints), so comparisons use the string form.
        """
        return [str(option) for option in self.options or []]

    def _validate_enum_option(self, value: str) -> None:
        """Validate that a value is in the allowed enum options."""
        if not self.options:
            return
        allowed = self._allowed_options()
        if value not in allowed:
            raise ValueError(f"value must be one of: {', '.join(allowed)}")

    def _validate_regex_pattern(self, value: str, pattern: re.Pattern, error_msg: str) -> None:
        """Validate that a value fully matches a regex pattern."""
        if not pattern.fullmatch(value):
            raise ValueError(error_msg)

    def _validate_url_structure(self, parsed_url) -> None:
        """Validate that a parsed URL has both a scheme and a network location."""
        if not (parsed_url.scheme and parsed_url.netloc):
            raise ValueError("value must be a valid URL (include scheme and host)")

    # !SECTION

    # -------------------------
    # SECTION: Type Conversion
    # -------------------------

    def convert(self, value: Any) -> Any:
        """Validate and convert a raw value based on the variable type.

        None and empty/whitespace-only strings convert to None.

        Raises:
            ValueError: If the value is not valid for this variable's type.
        """
        if value is None:
            return None
        # Treat empty strings as None to avoid storing "" for missing values.
        if isinstance(value, str) and value.strip() == "":
            return None

        # Type conversion mapping for cleaner code
        converters = {
            "bool": self._convert_bool,
            "int": self._convert_int,
            "float": self._convert_float,
            "enum": self._convert_enum,
            "hostname": self._convert_hostname,
            "url": self._convert_url,
            "email": self._convert_email,
        }
        converter = converters.get(self.type)
        if converter:
            return converter(value)
        # Default to string conversion
        return str(value)

    def _convert_bool(self, value: Any) -> bool:
        """Convert a bool or a recognised true/false string to bool."""
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in TRUE_VALUES:
                return True
            if lowered in FALSE_VALUES:
                return False
        raise ValueError("value must be a boolean (true/false)")

    def _convert_int(self, value: Any) -> Optional[int]:
        """Convert value to integer, rejecting non-integral floats."""
        if isinstance(value, int):
            return value
        if isinstance(value, str) and value.strip() == "":
            return None
        if isinstance(value, float) and not value.is_integer():
            # int() would silently truncate 3.7 -> 3 (and raise OverflowError,
            # not ValueError, for inf); reject these explicitly instead.
            raise ValueError("value must be an integer")
        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise ValueError("value must be an integer") from exc

    def _convert_float(self, value: Any) -> Optional[float]:
        """Convert value to float."""
        if isinstance(value, float):
            return value
        if isinstance(value, str) and value.strip() == "":
            return None
        try:
            return float(value)
        except (TypeError, ValueError) as exc:
            raise ValueError("value must be a float") from exc

    def _convert_enum(self, value: Any) -> Optional[str]:
        """Convert value to its string form and check it against the options."""
        if value == "":
            return None
        val = str(value)
        self._validate_enum_option(val)
        return val

    def _convert_hostname(self, value: Any) -> Optional[str]:
        """Convert and validate hostname ("localhost" is always accepted)."""
        val = str(value).strip()
        if not val:
            return None
        if val.lower() != "localhost":
            self._validate_regex_pattern(val, HOSTNAME_REGEX, "value must be a valid hostname")
        return val

    def _convert_url(self, value: Any) -> Optional[str]:
        """Convert and validate URL (must include scheme and host)."""
        val = str(value).strip()
        if not val:
            return None
        parsed = urlparse(val)
        self._validate_url_structure(parsed)
        return val

    def _convert_email(self, value: Any) -> Optional[str]:
        """Convert and validate email."""
        val = str(value).strip()
        if not val:
            return None
        self._validate_regex_pattern(val, EMAIL_REGEX, "value must be a valid email address")
        return val

    def get_typed_value(self) -> Any:
        """Return the stored value converted to the appropriate Python type.

        Raises:
            ValueError: If the stored value does not validate for this type.
        """
        return self.convert(self.value)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize Variable to a dictionary for storage.

        The name and section are not included; callers key the result by name.

        Returns:
            Dictionary representation of the variable with only relevant fields.
        """
        var_dict = {}
        if self.type:
            var_dict["type"] = self.type
        if self.value is not None:
            var_dict["default"] = self.value
        if self.description:
            var_dict["description"] = self.description
        if self.prompt:
            var_dict["prompt"] = self.prompt
        if self.sensitive:
            var_dict["sensitive"] = self.sensitive
        if self.extra:
            var_dict["extra"] = self.extra
        if self.options:
            var_dict["options"] = self.options
        if self.origin:
            var_dict["origin"] = self.origin
        return var_dict

    # !SECTION

    # -------------------------
    # SECTION: Display Methods
    # -------------------------

    def get_display_value(self, mask_sensitive: bool = True, max_length: int = 30) -> str:
        """Get formatted display value with optional masking and truncation.

        Args:
            mask_sensitive: If True, mask sensitive values with asterisks
            max_length: Maximum length before truncation (0 = no limit)

        Returns:
            Formatted string representation of the value, never longer than
            max_length when max_length > 0.
        """
        if self.value is None:
            return ""

        # Mask sensitive values
        if self.sensitive and mask_sensitive:
            return "********"

        # Convert to string
        display = str(self.value)

        # Truncate if needed
        if max_length > 0 and len(display) > max_length:
            if max_length < 3:
                # No room for an ellipsis; hard-truncate rather than return a
                # string longer than max_length.
                return display[:max_length]
            return display[:max_length - 3] + "..."
        return display

    def get_normalized_default(self) -> Any:
        """Get normalized default value suitable for prompts and display.

        Handles type conversion and provides sensible defaults for different types.
        Especially useful for enum, bool, and int types in interactive prompts.

        Returns:
            Normalized default value appropriate for the variable type
        """
        try:
            typed = self.get_typed_value()
        except Exception:
            # Best effort: fall back to the raw stored value if it no longer validates.
            typed = self.value

        # Enum: ensure default is valid option
        if self.type == "enum":
            if not self.options:
                return typed
            allowed = self._allowed_options()
            # If typed is invalid or missing, use first option
            if typed is None or str(typed) not in allowed:
                return allowed[0]
            return str(typed)

        # Boolean: return as bool type
        if self.type == "bool":
            if isinstance(typed, bool):
                return typed
            return None if typed is None else bool(typed)

        # Integer: return as int type
        if self.type == "int":
            try:
                return int(typed) if typed is not None and typed != "" else None
            except Exception:
                return None

        # Default: return string or None
        return None if typed is None else str(typed)

    def get_prompt_text(self) -> str:
        """Get formatted prompt text for interactive input.

        Returns:
            Prompt text (prompt, else description, else name) with a type hint
            appended for hostname/email/url variables that have a value.
        """
        prompt_text = self.prompt or self.description or self.name
        # Add type hint for semantic types if there's a default
        if self.value is not None and self.type in ["hostname", "email", "url"]:
            prompt_text += f" ({self.type})"
        return prompt_text

    def get_validation_hint(self) -> Optional[str]:
        """Get validation hint for prompts (e.g., enum options).

        Returns:
            Formatted hint string or None if no hint needed
        """
        hints = []

        # Add enum options
        if self.type == "enum" and self.options:
            hints.append(f"Options: {', '.join(self._allowed_options())}")

        # Add extra help text
        if self.extra:
            hints.append(self.extra)

        return " — ".join(hints) if hints else None

    def clone(self, update: Optional[Dict[str, Any]] = None) -> 'Variable':
        """Create an independent copy of the variable with optional field updates.

        The options list is shallow-copied; the value is re-validated by the
        constructor (so updating 'type' re-converts it). This is cheaper than
        converting to dict and back.

        Args:
            update: Optional dictionary of field updates to apply to the clone

        Returns:
            New Variable instance with copied data

        Raises:
            ValueError: If the (updated) value is invalid for the (updated) type.

        Example:
            var2 = var1.clone(update={'origin': 'template'})
        """
        data = {
            'name': self.name,
            'type': self.type,
            'value': self.value,
            'description': self.description,
            'prompt': self.prompt,
            'options': self.options.copy() if self.options else None,
            'section': self.section,
            'origin': self.origin,
            'sensitive': self.sensitive,
            'extra': self.extra,
        }

        # Apply updates if provided
        if update:
            data.update(update)

        return Variable(data)

    # !SECTION

# !SECTION
  # ----------------------------
# SECTION: VariableSection Class
# ----------------------------
class VariableSection:
    """Groups variables together with shared metadata for presentation."""

    def __init__(self, data: dict[str, Any]) -> None:
        """Build a section from a dictionary of section metadata.

        Args:
            data: Mapping that must contain 'key' and 'title'; 'prompt',
                'description', 'toggle' and 'required' are optional. When
                'required' is absent it defaults to True only for the
                "general" section.

        Raises:
            ValueError: If data is not a dict or lacks 'key' or 'title'.
        """
        if not isinstance(data, dict):
            raise ValueError("VariableSection data must be a dictionary")
        for mandatory in ("key", "title"):
            if mandatory not in data:
                raise ValueError(f"VariableSection data must contain '{mandatory}'")

        self.key: str = data["key"]
        self.title: str = data["title"]
        self.variables: OrderedDict[str, Variable] = OrderedDict()
        self.prompt: Optional[str] = data.get("prompt")
        self.description: Optional[str] = data.get("description")
        self.toggle: Optional[str] = data.get("toggle")
        # Only the "general" section is required unless stated otherwise.
        self.required: bool = data["required"] if "required" in data else self.key == "general"

    def variable_names(self) -> list[str]:
        """Return the section's variable names in insertion order."""
        return [*self.variables]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the section (metadata plus variables) for storage.

        Empty metadata fields are omitted; 'required' and 'vars' are always present.
        """
        metadata = (
            ("title", self.title),
            ("description", self.description),
            ("prompt", self.prompt),
            ("toggle", self.toggle),
        )
        serialized = {field_name: field_value for field_name, field_value in metadata if field_value}
        serialized["required"] = self.required
        serialized["vars"] = {name: var.to_dict() for name, var in self.variables.items()}
        return serialized

    # -------------------------
    # SECTION: State Methods
    # -------------------------

    def _lookup_toggle(self) -> Optional["Variable"]:
        """Return the toggle variable, or None if no toggle is set or it is absent."""
        if not self.toggle:
            return None
        return self.variables.get(self.toggle)

    def is_enabled(self) -> bool:
        """Report whether the section is active according to its toggle.

        Sections without a (present) toggle variable are always enabled; a toggle
        whose value cannot be converted counts as disabled.
        """
        switch = self._lookup_toggle()
        if not switch:
            return True
        try:
            return bool(switch.get_typed_value())
        except Exception:
            return False

    def get_toggle_value(self) -> Optional[bool]:
        """Return the toggle variable's truthiness.

        Returns None when there is no toggle, the toggle variable is absent, or
        its value cannot be converted.
        """
        switch = self._lookup_toggle()
        if not switch:
            return None
        try:
            return bool(switch.get_typed_value())
        except Exception:
            return None

    def clone(self, origin_update: Optional[str] = None) -> 'VariableSection':
        """Copy the section's metadata and clone each of its variables.

        Args:
            origin_update: When truthy, set as the 'origin' of every cloned variable.

        Returns:
            A new VariableSection that shares no Variable objects with this one.

        Example:
            section2 = section1.clone(origin_update='template')
        """
        duplicate = VariableSection({
            'key': self.key,
            'title': self.title,
            'prompt': self.prompt,
            'description': self.description,
            'toggle': self.toggle,
            'required': self.required,
        })
        overrides = {'origin': origin_update} if origin_update else None
        for name, var in self.variables.items():
            duplicate.variables[name] = var.clone(update=overrides)
        return duplicate

    # !SECTION

# !SECTION
  # --------------------------------
# SECTION: VariableCollection Class
# --------------------------------
class VariableCollection:
    """Manages variables grouped by sections and builds Jinja context."""

    def __init__(self, spec: dict[str, Any]) -> None:
        """Initialize VariableCollection from a specification dictionary.

        Args:
            spec: Dictionary containing the complete variable specification structure
                Expected format (as used in compose.py):
                {
                    "section_key": {
                        "title": "Section Title",
                        "prompt": "Optional prompt text",
                        "toggle": "optional_toggle_var_name",
                        "description": "Optional description",
                        "vars": {
                            "var_name": {
                                "description": "Variable description",
                                "type": "str",
                                "default": "default_value",
                                ...
                            }
                        }
                    }
                }
                Top-level entries whose value is not a dict are ignored.

        Raises:
            ValueError: If spec is not a dict or a variable has an invalid default.
        """
        if not isinstance(spec, dict):
            raise ValueError("Spec must be a dictionary")

        self._sections: Dict[str, VariableSection] = {}
        # NOTE: The _variable_map provides a flat, O(1) lookup for any variable by its name,
        # avoiding the need to iterate through sections. It stores references to the same
        # Variable objects contained in _sections. If two sections declare the same
        # variable name, the one processed last wins in this map.
        self._variable_map: Dict[str, Variable] = {}
        self._initialize_sections(spec)

    def _initialize_sections(self, spec: dict[str, Any]) -> None:
        """Initialize sections from the spec, skipping non-dict entries."""
        for section_key, section_data in spec.items():
            if not isinstance(section_data, dict):
                continue

            section = self._create_section(section_key, section_data)
            # Guard against None from empty YAML sections (vars: with no content)
            vars_data = section_data.get("vars") or {}
            self._initialize_variables(section, vars_data)
            self._sections[section_key] = section

    def _create_section(self, key: str, data: dict[str, Any]) -> VariableSection:
        """Create a VariableSection from data, deriving a title from the key if absent."""
        section_init_data = {
            "key": key,
            "title": data.get("title", key.replace("_", " ").title()),
            "prompt": data.get("prompt"),
            "description": data.get("description"),
            "toggle": data.get("toggle"),
            "required": data.get("required", key == "general")
        }
        return VariableSection(section_init_data)

    def _initialize_variables(self, section: VariableSection, vars_data: dict[str, Any]) -> None:
        """Create Variables for a section and register them in the flat lookup map."""
        # Guard against None from empty YAML sections
        for var_name, var_data in (vars_data or {}).items():
            # A bare `var_name:` YAML entry yields None; treat it as an empty spec
            # instead of failing on `**None`.
            var_init_data = {"name": var_name, **(var_data or {})}
            variable = Variable(var_init_data)
            section.variables[var_name] = variable
            # NOTE: Populate the direct lookup map for efficient access.
            self._variable_map[var_name] = variable

    # -------------------------
    # SECTION: Public API Methods
    # -------------------------

    def get_sections(self) -> Dict[str, VariableSection]:
        """Get a shallow copy of the section mapping."""
        return self._sections.copy()

    def get_section(self, key: str) -> Optional[VariableSection]:
        """Get a specific section by its key."""
        return self._sections.get(key)

    def has_sections(self) -> bool:
        """Check if the collection has any sections."""
        return bool(self._sections)

    def get_all_values(self) -> dict[str, Any]:
        """Get all variable values, converted to their Python types, keyed by name.

        Raises:
            ValueError: If any stored value fails conversion.
        """
        return {var_name: variable.get_typed_value() for var_name, variable in self._variable_map.items()}

    def get_sensitive_variables(self) -> Dict[str, Any]:
        """Get only the sensitive variables that have a truthy value."""
        return {name: var.value for name, var in self._variable_map.items() if var.sensitive and var.value}

    # !SECTION

    # -------------------------
    # SECTION: Helper Methods
    # -------------------------

    # NOTE: These helper methods reduce code duplication across module.py and prompt.py
    # by centralizing common variable collection operations

    def apply_defaults(self, defaults: dict[str, Any], origin: str = "cli") -> list[str]:
        """Apply default values to variables, updating their origin.

        Unknown names are logged and skipped; values that fail conversion are
        logged and leave the variable unchanged.

        Args:
            defaults: Dictionary mapping variable names to their default values
            origin: Source of these defaults (e.g., 'config', 'cli')

        Returns:
            List of variable names that were successfully updated
        """
        successful = []
        errors = []

        for var_name, value in defaults.items():
            variable = self._variable_map.get(var_name)
            if variable is None:
                logger.warning("Variable '%s' not found in template", var_name)
                continue

            try:
                converted_value = variable.convert(value)
            except ValueError as e:
                error_msg = f"Invalid value for '{var_name}': {value} - {e}"
                errors.append(error_msg)
                logger.error(error_msg)
                continue

            variable.value = converted_value
            # Set origin to the current source (not a chain)
            variable.origin = origin
            successful.append(var_name)

        if errors:
            logger.warning("Some defaults failed to apply: %s", '; '.join(errors))

        return successful

    def validate_all(self) -> None:
        """Validate all variables in the collection, skipping disabled sections.

        A section is skipped when its toggle variable exists and evaluates falsy.
        All problems are collected before raising so the caller sees every one.

        Raises:
            ValueError: Listing each missing, empty, or invalid variable as
                ``section.var (reason)``.
        """
        errors: list[str] = []

        for section in self._sections.values():
            # Check if the section is disabled by a toggle
            toggle_var = section.variables.get(section.toggle) if section.toggle else None
            if toggle_var is not None:
                try:
                    section_enabled = bool(toggle_var.get_typed_value())
                except ValueError as e:
                    # Report an unparseable toggle like any other invalid variable
                    # instead of aborting validation; its section cannot be evaluated.
                    errors.append(f"{section.key}.{section.toggle} (invalid: {e})")
                    continue
                if not section_enabled:
                    logger.debug("Skipping validation for disabled section: '%s'", section.key)
                    continue  # Skip this entire section

            # Validate each variable in the section
            for var_name, variable in section.variables.items():
                try:
                    # If value is None, treat as missing
                    if variable.value is None:
                        errors.append(f"{section.key}.{var_name} (missing)")
                        continue

                    # Attempt to convert/validate typed value
                    typed = variable.get_typed_value()

                    # For non-boolean types, treat None or empty string as invalid
                    if variable.type not in ("bool",) and (typed is None or typed == ""):
                        errors.append(f"{section.key}.{var_name} (empty)")

                except ValueError as e:
                    errors.append(f"{section.key}.{var_name} (invalid: {e})")

        if errors:
            error_msg = "Variable validation failed: " + ", ".join(errors)
            logger.error(error_msg)
            raise ValueError(error_msg)

    def merge(self, other_spec: Union[Dict[str, Any], 'VariableCollection'], origin: str = "override") -> 'VariableCollection':
        """Merge another spec or VariableCollection into this one with precedence tracking.

        Works directly on objects without dict conversions.
        The other spec/collection has higher precedence and will override values in self.
        Creates a new VariableCollection with merged data; neither input is modified.

        Args:
            other_spec: Either a spec dictionary or another VariableCollection to merge
            origin: Origin label for variables from other_spec (e.g., 'template', 'config')

        Returns:
            New VariableCollection with merged data

        Raises:
            ValueError: If other_spec is an invalid spec, or an overriding value is
                invalid for the merged variable's type.

        Example:
            module_vars = VariableCollection(module_spec)
            template_vars = module_vars.merge(template_spec, origin='template')
            # Variables from template_spec override module_spec.
            # Variables touched by template_spec get origin 'template' (the latest
            # source only, not a chain); untouched variables keep their origin.
        """
        # Convert dict to VariableCollection if needed (only once)
        if isinstance(other_spec, dict):
            other = VariableCollection(other_spec)
        else:
            other = other_spec

        # Create new collection without calling __init__ (optimization)
        merged = VariableCollection.__new__(VariableCollection)
        merged._sections = {}
        merged._variable_map = {}

        # First pass: clone sections from self
        for section_key, self_section in self._sections.items():
            if section_key in other._sections:
                # Section exists in both - will merge
                merged._sections[section_key] = self._merge_sections(
                    self_section,
                    other._sections[section_key],
                    origin
                )
            else:
                # Section only in self - clone it
                merged._sections[section_key] = self_section.clone()

        # Second pass: add sections that only exist in other
        for section_key, other_section in other._sections.items():
            if section_key not in merged._sections:
                # New section from other - clone with origin update
                merged._sections[section_key] = other_section.clone(origin_update=origin)

        # Rebuild variable map for O(1) lookups
        for section in merged._sections.values():
            for var_name, variable in section.variables.items():
                merged._variable_map[var_name] = variable

        return merged

    def _infer_origin_from_context(self) -> str:
        """Return the first non-empty variable origin found, else "template"."""
        for section in self._sections.values():
            for variable in section.variables.values():
                if variable.origin:
                    return variable.origin
        return "template"

    def _merge_sections(self, self_section: VariableSection, other_section: VariableSection, origin: str) -> VariableSection:
        """Merge two sections, with other_section taking precedence.

        Args:
            self_section: Base section
            other_section: Section to merge in (takes precedence)
            origin: Origin label for merged variables

        Returns:
            New merged VariableSection

        Raises:
            ValueError: If an overriding value is invalid for the merged type.
        """
        # Start with a clone of self_section
        merged_section = self_section.clone()

        # Update section metadata from other (other takes precedence)
        # NOTE(review): other_section.title is always truthy (a title is derived from
        # the key when the spec omits it) and 'required' always has a value, so a spec
        # that merely mentions a section can still override these. Confirm intended.
        if other_section.title:
            merged_section.title = other_section.title
        if other_section.prompt:
            merged_section.prompt = other_section.prompt
        if other_section.description:
            merged_section.description = other_section.description
        if other_section.toggle:
            merged_section.toggle = other_section.toggle
        # Required flag always updated
        merged_section.required = other_section.required

        # Merge variables
        for var_name, other_var in other_section.variables.items():
            if var_name in merged_section.variables:
                # Variable exists in both - merge with other taking precedence
                self_var = merged_section.variables[var_name]

                # Build update dict with other's values taking precedence
                update = {}
                # Variable() fills a missing type with "str", so "str" on the overriding
                # side normally means "not specified". Letting it win would silently
                # downgrade e.g. an int/bool module variable to a string (making "False"
                # truthy). Only a non-default type overrides the base type.
                if other_var.type and other_var.type != "str":
                    update['type'] = other_var.type
                if other_var.value is not None:
                    update['value'] = other_var.value
                if other_var.description:
                    update['description'] = other_var.description
                if other_var.prompt:
                    update['prompt'] = other_var.prompt
                if other_var.options:
                    # Copy so the merged collection doesn't share the list with `other`.
                    update['options'] = list(other_var.options)
                if other_var.sensitive:
                    update['sensitive'] = other_var.sensitive
                if other_var.extra:
                    update['extra'] = other_var.extra

                # Update origin tracking (only keep the current source, not the chain)
                update['origin'] = origin

                # Clone with updates (the constructor re-validates the value
                # against the resulting type)
                merged_section.variables[var_name] = self_var.clone(update=update)
            else:
                # New variable from other - clone with origin
                merged_section.variables[var_name] = other_var.clone(update={'origin': origin})

        return merged_section

    def filter_to_used(self, used_variables: Set[str], keep_sensitive: bool = True) -> 'VariableCollection':
        """Filter collection to only variables that are used (or sensitive).

        Works directly on objects without dict conversions.
        Creates a new VariableCollection containing clones of the variables in
        used_variables. Sections with no remaining variables are removed.

        Args:
            used_variables: Set of variable names that are actually used
            keep_sensitive: If True, also keep sensitive variables even if not in used set

        Returns:
            New VariableCollection with filtered variables

        Example:
            all_vars = VariableCollection(spec)
            used_vars = all_vars.filter_to_used({'var1', 'var2', 'var3'})
            # Only var1, var2, var3 (and any sensitive vars) remain
        """
        # Create new collection without calling __init__ (optimization)
        filtered = VariableCollection.__new__(VariableCollection)
        filtered._sections = {}
        filtered._variable_map = {}

        # Filter each section
        for section_key, section in self._sections.items():
            # Create a new section with same metadata
            filtered_section = VariableSection({
                'key': section.key,
                'title': section.title,
                'prompt': section.prompt,
                'description': section.description,
                'toggle': section.toggle,
                'required': section.required,
            })

            # Clone only the variables that should be included
            for var_name, variable in section.variables.items():
                # Include if used OR if sensitive (and keep_sensitive is True)
                should_include = (
                    var_name in used_variables or
                    (keep_sensitive and variable.sensitive)
                )

                if should_include:
                    filtered_section.variables[var_name] = variable.clone()

            # Only add section if it has variables
            if filtered_section.variables:
                filtered._sections[section_key] = filtered_section
                # Add variables to map
                for var_name, variable in filtered_section.variables.items():
                    filtered._variable_map[var_name] = variable

        return filtered

    def get_all_variable_names(self) -> Set[str]:
        """Get set of all variable names across all sections.

        Returns:
            Set of all variable names
        """
        return set(self._variable_map.keys())

    # !SECTION

# !SECTION