diff --git a/lua/python/treesitter/commands.lua b/lua/python/treesitter/commands.lua index 7afcc32..e879bbb 100644 --- a/lua/python/treesitter/commands.lua +++ b/lua/python/treesitter/commands.lua @@ -209,6 +209,26 @@ function PythonTreeSitterCommands.ts_wrap_at_cursor(subtitute_option) end) end +--- +---@param node TSNode the current ts node we are checking for parents +---@return string callText check if this node has a "call" type node 3 parents up +--- this is used for checking on "".format() calls for strings. +local function checkForFStringCallParent(node) + local callStatus, callText = pcall(function() + local callNode = node:parent():parent():parent() + if callNode then + local text = getNodeText(callNode) + return text + end + return "" + end) -- Get potential function call on string for .format() + + if not callStatus then + callText = "" + end + return callText +end + function PythonTreeSitterCommands.pythonFStr() local maxCharacters = 200 -- safeguard to prevent converting invalid code local node = getNodeAtCursor() @@ -217,6 +237,8 @@ function PythonTreeSitterCommands.pythonFStr() end local strNode + local callText = checkForFStringCallParent(node) + if node:type() == "string" then strNode = node elseif node:type():find("^string_") then @@ -239,10 +261,12 @@ function PythonTreeSitterCommands.pythonFStr() return end -- safeguard on converting invalid code + local isFormatString = callText:find([[^.*["']%.format%(]]) + local isRString = text:find("^r") local isFString = text:find("^r?f") -- rf -> raw-formatted-string local hasBraces = text:find("{.-[^%d,%s].-}") -- nonRegex-braces, see #12 and #15 - if not isFString and hasBraces then + if (not isFString and not isFormatString and not isRString) and hasBraces then text = "f" .. text replaceNodeText(strNode, text) elseif isFString and not hasBraces then diff --git a/tests/test_text_actions.lua b/tests/test_text_actions.lua index eb4604e..7cf481b 100644 --- a/tests/test_text_actions.lua +++ b/tests/test_text_actions.lua @@ -25,19 +25,31 @@ local get_lines = function() return child.api.nvim_buf_get_lines(0, 0, -1, true) end -T["text_actions"] = MiniTest.new_set({ - n_retry = 3, +T["f-string"] = MiniTest.new_set({ hooks = { pre_case = function() child.cmd("e _not_existing_new_buffer.py") + child.type_keys("cc", [["TEST"]], "", "0") end, }, }) -T["text_actions"]["insert_f_string"] = function() - child.type_keys("i", [[print("{foo}")]], "") +T["f-string"]["insert f string"] = function() + child.cmd("e! _not_existing_new_buffer.py") + child.type_keys("cc", [["{foo}"]], "", "hh", "i", "") + eq(get_lines(), { [[f"{foo}"]] }) +end + +T["f-string"]["skip on r"] = function() + child.cmd("e! _not_existing_new_buffer.py") + child.type_keys("cc", [[r"{foo}"]], "", "hh", "i", "") + eq(get_lines(), { [[r"{foo}"]] }) +end - eq(get_lines(), { [[print(f"{foo}")]] }) +T["f-string"]["skip on format"] = function() + child.cmd("e! _not_existing_new_buffer.py") + child.type_keys("cc", [["{foo}".format()]], "", "0lll", "i", "") + eq(get_lines(), { [["{foo}".format()]] }) end -- Return test set which will be collected and execute inside `MiniTest.run()`