prompt_manager.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
from __future__ import annotations

import logging
from typing import Any, Callable, ClassVar

from rich.console import Console
from rich.prompt import Prompt

from ..display import DisplayManager
from ..input import InputManager
from ..template import Variable, VariableCollection
  9. logger = logging.getLogger(__name__)
  10. class PromptHandler:
  11. """Simple interactive prompt handler for collecting template variables."""
  12. def __init__(self) -> None:
  13. self.console = Console()
  14. self.display = DisplayManager()
  15. def _handle_section_toggle(self, section, collected: dict[str, Any]) -> bool:
  16. """Handle section toggle variable and return whether section should be enabled."""
  17. if not section.toggle:
  18. return True
  19. toggle_var = section.variables.get(section.toggle)
  20. if not toggle_var:
  21. return True
  22. current_value = toggle_var.convert(toggle_var.value)
  23. new_value = self._prompt_variable(toggle_var, _required=False)
  24. if new_value != current_value:
  25. collected[toggle_var.name] = new_value
  26. toggle_var.value = new_value
  27. return section.is_enabled()
  28. def _should_skip_variable(
  29. self,
  30. var_name: str,
  31. section,
  32. variables: VariableCollection,
  33. section_enabled: bool,
  34. ) -> bool:
  35. """Determine if a variable should be skipped during collection."""
  36. if section.toggle and var_name == section.toggle:
  37. return True
  38. if not variables.is_variable_satisfied(var_name):
  39. logger.debug(f"Skipping variable '{var_name}' - needs not satisfied")
  40. return True
  41. if not section_enabled:
  42. logger.debug(f"Skipping variable '{var_name}' from disabled section '{section.key}'")
  43. return True
  44. return False
  45. def _collect_variable_value(self, variable: Variable, collected: dict[str, Any]) -> None:
  46. """Collect a single variable value and update if changed."""
  47. current_value = variable.convert(variable.value)
  48. new_value = self._prompt_variable(variable, _required=False)
  49. if variable.autogenerated and new_value is None:
  50. collected[variable.name] = None
  51. variable.value = None
  52. elif new_value != current_value:
  53. collected[variable.name] = new_value
  54. variable.value = new_value
  55. def collect_variables(self, variables: VariableCollection) -> dict[str, Any]:
  56. """Collect values for variables by iterating through sections.
  57. Args:
  58. variables: VariableCollection with organized sections and variables
  59. Returns:
  60. Dict of variable names to collected values
  61. """
  62. input_mgr = InputManager()
  63. if not input_mgr.confirm("Customize any settings?", default=False):
  64. logger.info("User opted to keep all default values")
  65. return {}
  66. collected: dict[str, Any] = {}
  67. for _section_key, section in variables.get_sections().items():
  68. if not section.variables:
  69. continue
  70. self.display.section(section.title, section.description)
  71. section_enabled = self._handle_section_toggle(section, collected)
  72. for var_name, variable in section.variables.items():
  73. if self._should_skip_variable(var_name, section, variables, section_enabled):
  74. continue
  75. self._collect_variable_value(variable, collected)
  76. logger.info(f"Variable collection completed. Collected {len(collected)} values")
  77. return collected
  78. def _prompt_variable(self, variable: Variable, _required: bool = False) -> Any:
  79. """Prompt for a single variable value based on its type.
  80. Args:
  81. variable: The variable to prompt for
  82. _required: Whether the containing section is required
  83. (unused, kept for API compatibility)
  84. Returns:
  85. The validated value entered by the user
  86. """
  87. logger.debug(f"Prompting for variable '{variable.name}' (type: {variable.type})")
  88. # Use variable's native methods for prompt text and default value
  89. prompt_text = variable.get_prompt_text()
  90. default_value = variable.get_normalized_default()
  91. has_explicit_default = "default" in variable._explicit_fields or "value" in variable._explicit_fields
  92. has_applied_default = variable.origin in {"config", "var-file", "cli"}
  93. if (
  94. not has_explicit_default
  95. and not has_applied_default
  96. and not variable.autogenerated
  97. and not variable.is_required()
  98. ):
  99. default_value = None
  100. # Add lock icon before default value for secret or autogenerated variables
  101. if variable.is_secret() or variable.autogenerated:
  102. # Format: "Prompt text 🔒 (default)"
  103. # The lock icon goes between the text and the default value in parentheses
  104. prompt_text = f"{prompt_text} {self.display.get_lock_icon()}"
  105. if variable.config.placeholder:
  106. prompt_text = f"{prompt_text} [dim]({variable.config.placeholder})[/dim]"
  107. # Check if this specific variable is required (has no default and not autogenerated)
  108. var_is_required = variable.is_required()
  109. # If variable is required, mark it in the prompt
  110. if var_is_required:
  111. prompt_text = f"{prompt_text} [bold red]*required[/bold red]"
  112. allow_empty = not var_is_required and default_value is None
  113. handler = self._get_prompt_handler(variable, allow_empty=allow_empty)
  114. # Add validation hint (includes both extra text and enum options)
  115. hint = variable.get_validation_hint()
  116. if hint:
  117. prompt_text = f"{prompt_text} [dim]({hint})[/dim]"
  118. while True:
  119. try:
  120. raw = handler(prompt_text, default_value)
  121. # Use Variable's centralized validation method that handles:
  122. # - Type conversion
  123. # - Autogenerated variable detection
  124. # - Required field validation
  125. return variable.validate_and_convert(raw, check_required=True)
  126. # Return the converted value (caller will update variable.value)
  127. except ValueError as exc:
  128. # Conversion/validation failed — show a consistent error message and retry
  129. self._show_validation_error(str(exc))
  130. except Exception as e:
  131. # Unexpected error — log and retry using the stored (unconverted) value
  132. logger.error(f"Error prompting for variable '{variable.name}': {e!s}")
  133. default_value = variable.value
  134. handler = self._get_prompt_handler(variable, allow_empty=allow_empty)
  135. def _get_prompt_handler(self, variable: Variable, allow_empty: bool = False) -> Callable:
  136. """Return the prompt function for a variable type."""
  137. handlers = {
  138. "bool": lambda text, default: self._prompt_bool(text, default, allow_empty=allow_empty),
  139. "int": lambda text, default: self._prompt_int(
  140. text,
  141. default,
  142. allow_empty=allow_empty,
  143. min_value=variable.config.min if variable.config else None,
  144. max_value=variable.config.max if variable.config else None,
  145. ),
  146. # For enum prompts we pass the variable.extra through so options and extra
  147. # can be combined into a single inline hint.
  148. "enum": lambda text, default: self._prompt_enum(
  149. text,
  150. variable.options or [],
  151. default,
  152. allow_empty=allow_empty,
  153. _extra=getattr(variable, "extra", None),
  154. ),
  155. }
  156. return handlers.get(
  157. variable.type,
  158. lambda text, default: self._prompt_string(text, default, is_secret=variable.is_secret()),
  159. )
  160. def _show_validation_error(self, message: str) -> None:
  161. """Display validation feedback consistently."""
  162. self.display.error(message)
  163. def _prompt_string(self, prompt_text: str, default: Any = None, is_secret: bool = False) -> str | None:
  164. if is_secret:
  165. value = Prompt.ask(
  166. prompt_text,
  167. default="",
  168. show_default=False,
  169. password=True,
  170. )
  171. stripped = value.strip() if value else None
  172. if stripped:
  173. return stripped
  174. return default
  175. value = Prompt.ask(
  176. prompt_text,
  177. default=str(default) if default is not None else "",
  178. show_default=True,
  179. password=False,
  180. )
  181. stripped = value.strip() if value else None
  182. return stripped if stripped else None
  183. def _prompt_bool(self, prompt_text: str, default: Any = None, allow_empty: bool = False) -> bool | str | None:
  184. input_mgr = InputManager()
  185. if allow_empty and default is None:
  186. value = Prompt.ask(prompt_text, default="", show_default=False)
  187. stripped = value.strip() if value else ""
  188. return stripped if stripped else None
  189. if default is None:
  190. return input_mgr.confirm(prompt_text, default=None)
  191. converted = default if isinstance(default, bool) else str(default).lower() in ("true", "1", "yes", "on")
  192. return input_mgr.confirm(prompt_text, default=converted)
  193. def _prompt_int(
  194. self,
  195. prompt_text: str,
  196. default: Any = None,
  197. allow_empty: bool = False,
  198. min_value: int | None = None,
  199. max_value: int | None = None,
  200. ) -> int | str | None:
  201. converted = None
  202. if default is not None:
  203. try:
  204. converted = int(default)
  205. except (ValueError, TypeError):
  206. logger.warning(f"Invalid default integer value: {default}")
  207. if allow_empty and converted is None:
  208. value = Prompt.ask(prompt_text, default="", show_default=False)
  209. stripped = value.strip() if value else ""
  210. return stripped if stripped else None
  211. input_mgr = InputManager()
  212. return input_mgr.integer(
  213. prompt_text,
  214. default=converted,
  215. min_value=min_value,
  216. max_value=max_value,
  217. )
  218. def _prompt_enum(
  219. self,
  220. prompt_text: str,
  221. options: list[str],
  222. default: Any = None,
  223. allow_empty: bool = False,
  224. _extra: str | None = None,
  225. ) -> str | None:
  226. """Prompt for enum selection with validation.
  227. Note: prompt_text should already include hint from variable.get_validation_hint()
  228. but we keep this for backward compatibility and fallback.
  229. """
  230. if not options:
  231. return self._prompt_string(prompt_text, default)
  232. # Validate default is in options
  233. if default and str(default) not in options:
  234. default = options[0]
  235. if allow_empty and default is None:
  236. value = Prompt.ask(prompt_text, default="", show_default=False)
  237. stripped = value.strip() if value else ""
  238. return stripped if stripped else None
  239. while True:
  240. value = Prompt.ask(
  241. prompt_text,
  242. default=str(default) if default else options[0],
  243. show_default=True,
  244. )
  245. if value in options:
  246. return value
  247. self.console.print(f"[red]Invalid choice. Select from: {', '.join(options)}[/red]")