Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion lua/python/treesitter/commands.lua
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ function PythonTreeSitterCommands.ts_wrap_at_cursor(subtitute_option)
end)
end

---
---@param node TSNode the current ts node we are checking for parents
---@return string callText check if this node has a "call" type node 3 parents up
--- this is used for checking on "".format() calls for strings.
local function checkForFStringCallParent(node)
local callStatus, callText = pcall(function()
local callNode = node:parent():parent():parent()
if callNode then
local text = getNodeText(callNode)
return text
end
return ""
end) -- Get potential function call on string for .format()

if not callStatus then
callText = ""
end
return callText
end

function PythonTreeSitterCommands.pythonFStr()
local maxCharacters = 200 -- safeguard to prevent converting invalid code
local node = getNodeAtCursor()
Expand All @@ -217,6 +237,8 @@ function PythonTreeSitterCommands.pythonFStr()
end

local strNode
local callText = checkForFStringCallParent(node)

if node:type() == "string" then
strNode = node
elseif node:type():find("^string_") then
Expand All @@ -239,10 +261,12 @@ function PythonTreeSitterCommands.pythonFStr()
return
end -- safeguard on converting invalid code

local isFormatString = callText:find([[^.*["']%.format%(]])
local isRString = text:find("^r")
local isFString = text:find("^r?f") -- rf -> raw-formatted-string
local hasBraces = text:find("{.-[^%d,%s].-}") -- nonRegex-braces, see #12 and #15

if not isFString and hasBraces then
if (not isFString and not isFormatString and not isRString) and hasBraces then
text = "f" .. text
replaceNodeText(strNode, text)
elseif isFString and not hasBraces then
Expand Down
22 changes: 17 additions & 5 deletions tests/test_text_actions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,31 @@ local get_lines = function()
return child.api.nvim_buf_get_lines(0, 0, -1, true)
end

T["text_actions"] = MiniTest.new_set({
n_retry = 3,
T["f-string"] = MiniTest.new_set({
hooks = {
pre_case = function()
child.cmd("e _not_existing_new_buffer.py")
child.type_keys("cc", [["TEST"]], "<Esc>", "0")
end,
},
})

T["text_actions"]["insert_f_string"] = function()
child.type_keys("i", [[print("{foo}")]], "<left><esc>")
T["f-string"]["insert f string"] = function()
child.cmd("e! _not_existing_new_buffer.py")
child.type_keys("cc", [["{foo}"]], "<Esc>", "hh", "i", "<Esc>")
eq(get_lines(), { [[f"{foo}"]] })
end

T["f-string"]["skip on r"] = function()
child.cmd("e! _not_existing_new_buffer.py")
child.type_keys("cc", [[r"{foo}"]], "<Esc>", "hh", "i", "<Esc>")
eq(get_lines(), { [[r"{foo}"]] })
end

eq(get_lines(), { [[print(f"{foo}")]] })
T["f-string"]["skip on format"] = function()
child.cmd("e! _not_existing_new_buffer.py")
child.type_keys("cc", [["{foo}".format()]], "<Esc>", "0lll", "i", "<Esc>")
eq(get_lines(), { [["{foo}".format()]] })
end

-- Return test set which will be collected and execute inside `MiniTest.run()`
Expand Down