330 lines
7.6 KiB
Go
330 lines
7.6 KiB
Go
package autopatch
|
|
|
|
import (
|
|
"ai-code-assistant/pkg/config"
|
|
"ai-code-assistant/pkg/database"
|
|
"ai-code-assistant/pkg/indexer"
|
|
"ai-code-assistant/pkg/llm"
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-git/go-billy/v5/osfs"
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing/cache"
|
|
"github.com/go-git/go-git/v5/storage/filesystem"
|
|
"github.com/sergi/go-diff/diffmatchpatch"
|
|
"github.com/urfave/cli/v3"
|
|
"log/slog"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
func Command() *cli.Command {
|
|
return &cli.Command{
|
|
Name: "auto-patch",
|
|
Usage: "this command accepts a repository and a prompt and will generate a git commit attempting code modifications to satisfy the prompt.",
|
|
Action: (&autoPatch{}).run,
|
|
Flags: []cli.Flag{
|
|
&cli.StringFlag{
|
|
Name: "repo",
|
|
Usage: "path to git repository",
|
|
Required: true,
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "task",
|
|
Usage: "task to perform, e.g. \"add a test for a function\"",
|
|
Required: true,
|
|
},
|
|
&cli.BoolFlag{
|
|
Name: "execute",
|
|
Usage: "actually execute the commit, otherwise just print the code to be committed",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
type autoPatch struct {
|
|
llm *llm.LLM
|
|
execute bool
|
|
}
|
|
|
|
func (a *autoPatch) run(ctx context.Context, cmd *cli.Command) error {
|
|
llmRef := llm.FromContext(ctx)
|
|
a.llm = llmRef
|
|
a.execute = cmd.Bool("execute")
|
|
|
|
// Make sure we're indexed.
|
|
idx := indexer.New(ctx, cmd.String("repo"), config.FromContext(ctx).IndexChunkSize, false)
|
|
if err := idx.Index(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Attempt to generate the commit.
|
|
err := a.generateGitCommit(ctx, cmd.String("repo"), cmd.String("task"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *autoPatch) generateGitCommit(ctx context.Context, repoPath, prompt string) error {
|
|
var affectedFiles []string
|
|
|
|
fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affectedFiles = append(affectedFiles, fileName)
|
|
|
|
// If we modified a test, we don't need to generate a test.
|
|
if !strings.HasSuffix(fileName, "_test.go") {
|
|
testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affectedFiles = append(affectedFiles, testFile)
|
|
}
|
|
|
|
if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *autoPatch) commit(ctx context.Context, prompt, repoPath string, files ...string) error {
|
|
gitPath := osfs.New(filepath.Join(repoPath, ".git"))
|
|
|
|
gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), osfs.New(repoPath))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
workTree, err := gitRepo.Worktree()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, file := range files {
|
|
// Relative paths.
|
|
file = strings.TrimPrefix(file, repoPath)
|
|
file = strings.TrimPrefix(file, "/")
|
|
|
|
if _, err := workTree.Add(file); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_commitmsg", map[string]any{
|
|
"Prompt": prompt,
|
|
"Files": files,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rsp, err := a.llm.ChatPrompt(ctx, genPrompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if a.execute {
|
|
if _, err := workTree.Commit(rsp, &git.CommitOptions{}); err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Info("committed changes to git repo", "repo", repoPath)
|
|
} else {
|
|
fmt.Printf("Commit Message:\n%s\n", rsp)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *autoPatch) generateCodePatch(ctx context.Context, repoPath, prompt string) (string, string, error) {
|
|
db := database.FromContext(ctx)
|
|
cfg := config.FromContext(ctx)
|
|
|
|
repoID, _, err := db.UpsertRepo(ctx, repoPath)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
relDocs := llm.NewGetRelevantDocs(db, a.llm, repoID, cfg.RelevantDocs)
|
|
chunks, err := relDocs.GetRelevantFileChunks(ctx, prompt)
|
|
if err != nil {
|
|
return "", "", err
|
|
} else if len(chunks) == 0 {
|
|
return "", "", errors.New("no relevant chunks found")
|
|
}
|
|
|
|
chunk := chunks[0]
|
|
|
|
slog.Info("found most relevant file chunk", "file", chunk.Name, "start", chunk.Start, "end", chunk.End, "score", chunk.Score, "id", chunk.ChunkID)
|
|
|
|
chunkContext, err := db.GetChunkContext(ctx, chunks[0].ChunkID, 1, chunks[0].Name, repoID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_patch", map[string]string{
|
|
"Prompt": prompt,
|
|
"Context": chunkContext,
|
|
})
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
codeBlock, err := a.generateCode(ctx, genPrompt)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
fileName := chunks[0].Name
|
|
originalFile, err := os.ReadFile(fileName)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
dmp := diffmatchpatch.New()
|
|
diffs := dmp.DiffMain(string(originalFile), codeBlock, true)
|
|
diffs = cleanDiffs(diffs)
|
|
|
|
if a.execute {
|
|
slog.Info("applying generated patch to file", "file", fileName)
|
|
if err := patchFile(fileName, diffs); err != nil {
|
|
return "", "", err
|
|
}
|
|
} else {
|
|
fmt.Printf("Code block:\n%s\n", codeBlock)
|
|
fmt.Printf("File to patch: %s\n", fileName)
|
|
fmt.Println(dmp.DiffPrettyText(diffs))
|
|
}
|
|
|
|
return fileName, codeBlock, err
|
|
}
|
|
|
|
func (a *autoPatch) generateUnitTest(ctx context.Context, prompt, fileName, newCode string) (string, error) {
|
|
// Check to see if a test file for this already exists.
|
|
testFileExists := false
|
|
|
|
testFile := strings.ReplaceAll(fileName, ".go", "_test.go")
|
|
if _, err := os.Stat(testFile); err == nil {
|
|
testFileExists = true
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_unittest", map[string]any{
|
|
"Prompt": prompt,
|
|
"Context": newCode,
|
|
"TestFileExists": testFileExists,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
codeBlock, err := a.generateCode(ctx, genPrompt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if a.execute {
|
|
slog.Info("applying generated unit test to file", "file", testFile)
|
|
|
|
if testFileExists {
|
|
fp, err := os.OpenFile(testFile, os.O_APPEND|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer fp.Close()
|
|
|
|
if _, err := fp.WriteString("\n" + codeBlock); err != nil {
|
|
return "", err
|
|
}
|
|
} else {
|
|
fp, err := os.Open(testFile)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer fp.Close()
|
|
|
|
if _, err := fp.WriteString(codeBlock); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
} else {
|
|
fmt.Printf("Unit Test Code block for %s:\n%s\n", testFile, codeBlock)
|
|
}
|
|
|
|
return testFile, nil
|
|
}
|
|
|
|
func (a *autoPatch) generateCode(ctx context.Context, prompt string) (string, error) {
|
|
rsp, err := a.llm.CodePrompt(ctx, prompt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
startIdx := strings.Index(rsp, "```")
|
|
endIdx := strings.LastIndex(rsp, "```")
|
|
|
|
if startIdx == -1 || endIdx == -1 || startIdx >= endIdx {
|
|
return "", fmt.Errorf("unable to find code block in response: %s", rsp)
|
|
}
|
|
|
|
codeBlock := rsp[startIdx+3 : endIdx]
|
|
if strings.HasPrefix(codeBlock, "go") {
|
|
codeBlock = codeBlock[2:]
|
|
}
|
|
|
|
return codeBlock, nil
|
|
}
|
|
|
|
func patchFile(fileName string, diffs []diffmatchpatch.Diff) error {
|
|
var buff bytes.Buffer
|
|
for _, diff := range diffs {
|
|
text := diff.Text
|
|
|
|
switch diff.Type {
|
|
case diffmatchpatch.DiffInsert:
|
|
_, _ = buff.WriteString(text)
|
|
case diffmatchpatch.DiffEqual:
|
|
_, _ = buff.WriteString(text)
|
|
}
|
|
}
|
|
|
|
if err := os.WriteFile(fileName, buff.Bytes(), 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanDiffs will ignore any deletions at the beginning or end of the context since the LLM may trim these off.
|
|
func cleanDiffs(diffs []diffmatchpatch.Diff) []diffmatchpatch.Diff {
|
|
startIdx := 0
|
|
endIdx := len(diffs)
|
|
|
|
for idx, diff := range diffs {
|
|
if diff.Type == diffmatchpatch.DiffDelete {
|
|
startIdx = idx + 1
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
for idx := len(diffs) - 1; idx >= 0; idx-- {
|
|
if diffs[idx].Type == diffmatchpatch.DiffDelete {
|
|
endIdx = idx - 1
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
return diffs[startIdx:endIdx]
|
|
}
|