Files
ai-code-assistant/cmd/autopatch/autopatch.go
Michael Powers 25f8cae8cb Code Cleanup and Quality of Life
Checks to make sure repo is indexed before generating code.
Don't generate tests for changes to tests.
Remove unused code.
Fix bootstrapping issue with langchaingo tables.
2025-04-20 08:31:26 -04:00

314 lines
7.0 KiB
Go

package autopatch
import (
"ai-code-assistant/pkg/config"
"ai-code-assistant/pkg/database"
"ai-code-assistant/pkg/indexer"
"ai-code-assistant/pkg/llm"
"bytes"
"context"
"errors"
"fmt"
"github.com/go-git/go-billy/v5/osfs"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/cache"
"github.com/go-git/go-git/v5/storage/filesystem"
"github.com/sergi/go-diff/diffmatchpatch"
"github.com/urfave/cli/v3"
"log/slog"
"os"
"path/filepath"
"strings"
)
// Command returns the "auto-patch" CLI command. It requires a --repo path and
// a --task description; its action (agent.run) tries to produce a git commit
// whose code changes carry out the task.
func Command() *cli.Command {
	flags := []cli.Flag{
		&cli.StringFlag{
			Name:     "repo",
			Usage:    "path to git repository",
			Required: true,
		},
		&cli.StringFlag{
			Name:     "task",
			Usage:    "task to perform, e.g. \"add a test for a function\"",
			Required: true,
		},
	}

	runner := &agent{}
	return &cli.Command{
		Name:   "auto-patch",
		Usage:  "this command accepts a repository and a prompt and will generate a git commit attempting code modifications to satisfy the prompt.",
		Flags:  flags,
		Action: runner.run,
	}
}
// agent holds the state shared by the auto-patch steps. run fills in the LLM
// handle from the command context before any other step executes.
type agent struct {
	// llm is the client used for both code prompts and chat prompts.
	llm *llm.LLM
}
// run is the action behind the auto-patch command. It takes the LLM from ctx,
// indexes the repository named by --repo, and then attempts to generate a
// commit that performs --task.
func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
	a.llm = llm.FromContext(ctx)
	repo := cmd.String("repo")

	// Make sure the repository is indexed before it is searched for context.
	ix := indexer.New(ctx, repo, config.FromContext(ctx).IndexChunkSize, false)
	if err := ix.Index(ctx); err != nil {
		return err
	}

	// Attempt to generate the commit.
	return a.generateGitCommit(ctx, repo, cmd.String("task"))
}
// generateGitCommit patches the file most relevant to prompt, generates a
// unit test for the change when appropriate, and commits every touched file
// to the git repository at repoPath.
func (a *agent) generateGitCommit(ctx context.Context, repoPath, prompt string) error {
	fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt)
	if err != nil {
		return err
	}
	affectedFiles := []string{fileName}

	// Only Go source files get a generated unit test. A patched test file needs
	// no test of its own, and non-Go files (docs, configs, ...) have no
	// "_test.go" counterpart to write one into.
	if strings.HasSuffix(fileName, ".go") && !strings.HasSuffix(fileName, "_test.go") {
		testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
		if err != nil {
			return err
		}
		affectedFiles = append(affectedFiles, testFile)
	}

	if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil {
		return err
	}
	slog.Info("committed changes to git repo", "repo", repoPath)
	return nil
}
// commit stages files in the git repository at repoPath and commits them with
// a message the LLM generates from prompt and the list of files.
//
// files may be absolute or relative to the current working directory. Each is
// converted to a path relative to the repository root before staging.
func (a *agent) commit(ctx context.Context, prompt, repoPath string, files ...string) error {
	gitPath := osfs.New(filepath.Join(repoPath, ".git"))
	gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), osfs.New(repoPath))
	if err != nil {
		return err
	}
	workTree, err := gitRepo.Worktree()
	if err != nil {
		return err
	}
	for _, file := range files {
		rel, err := repoRelPath(repoPath, file)
		if err != nil {
			return err
		}
		if _, err := workTree.Add(rel); err != nil {
			return err
		}
	}
	genPrompt, err := llm.GetPrompt("generate_commitmsg", map[string]any{
		"Prompt": prompt,
		"Files":  files,
	})
	if err != nil {
		return err
	}
	rsp, err := a.llm.ChatPrompt(ctx, genPrompt)
	if err != nil {
		return err
	}
	if _, err := workTree.Commit(rsp, &git.CommitOptions{}); err != nil {
		return err
	}
	return nil
}

// repoRelPath converts file into a forward-slash path relative to repoPath,
// the worktree-relative form Worktree.Add takes. Both paths are made absolute
// first, so "./repo" vs "repo/x.go" and trailing separators resolve correctly.
// A file outside the repository is reported as an error instead of being
// mangled into an unrelated path.
func repoRelPath(repoPath, file string) (string, error) {
	absRepo, err := filepath.Abs(repoPath)
	if err != nil {
		return "", err
	}
	absFile, err := filepath.Abs(file)
	if err != nil {
		return "", err
	}
	rel, err := filepath.Rel(absRepo, absFile)
	if err != nil {
		return "", err
	}
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("file %q is outside repository %q", file, repoPath)
	}
	return filepath.ToSlash(rel), nil
}
// generateCodePatch asks the LLM to modify the file chunk that best matches
// prompt, diffs the LLM's code block against that file, and writes the patched
// result back to disk.
//
// It returns the name of the patched file and the raw code block produced by
// the LLM, not the full patched file contents.
func (a *agent) generateCodePatch(ctx context.Context, repoPath, prompt string) (string, string, error) {
	db := database.FromContext(ctx)
	cfg := config.FromContext(ctx)

	repoID, _, err := db.UpsertRepo(ctx, repoPath)
	if err != nil {
		return "", "", err
	}

	// The first chunk returned by the relevance search is the patch target.
	finder := llm.NewGetRelevantDocs(db, a.llm, repoID, cfg.RelevantDocs)
	chunks, err := finder.GetRelevantFileChunks(ctx, prompt)
	if err != nil {
		return "", "", err
	}
	if len(chunks) == 0 {
		return "", "", errors.New("no relevant chunks found")
	}
	best := chunks[0]
	slog.Info("found most relevant file chunk", "file", best.Name, "start", best.Start, "end", best.End, "score", best.Score, "id", best.ChunkID)

	// Give the LLM the chunk plus its surrounding context.
	chunkContext, err := db.GetChunkContext(ctx, best.ChunkID, 1, best.Name, repoID)
	if err != nil {
		return "", "", err
	}
	patchPrompt, err := llm.GetPrompt("generate_patch", map[string]string{
		"Prompt":  prompt,
		"Context": chunkContext,
	})
	if err != nil {
		return "", "", err
	}
	codeBlock, err := a.generateCode(ctx, patchPrompt)
	if err != nil {
		return "", "", err
	}
	fmt.Printf("Code block:\n%s\n", codeBlock)

	// Diff the file on disk against the generated code and apply the result.
	target := best.Name
	original, err := os.ReadFile(target)
	if err != nil {
		return "", "", err
	}
	differ := diffmatchpatch.New()
	diffs := cleanDiffs(differ.DiffMain(string(original), codeBlock, true))
	fmt.Printf("File to patch: %s\n", target)
	fmt.Println(differ.DiffPrettyText(diffs))
	if err := patchFile(target, diffs); err != nil {
		return "", "", err
	}
	return target, codeBlock, nil
}
// generateUnitTest asks the LLM for a unit test covering newCode, the code
// generated for prompt in fileName. The test is written to fileName's
// "_test.go" companion: appended after a newline when that file already
// exists, or written to a newly created file otherwise.
//
// It returns the path of the test file. fileName must be a Go source file;
// any other file yields an error.
func (a *agent) generateUnitTest(ctx context.Context, prompt, fileName, newCode string) (string, error) {
	// Derive foo_test.go from foo.go. Only the trailing extension is replaced,
	// so a ".go" inside a directory name is left alone.
	if !strings.HasSuffix(fileName, ".go") {
		return "", fmt.Errorf("cannot generate a unit test for non-Go file %q", fileName)
	}
	testFile := strings.TrimSuffix(fileName, ".go") + "_test.go"

	// Check to see if a test file for this already exists.
	_, statErr := os.Stat(testFile)
	testFileExists := statErr == nil

	genPrompt, err := llm.GetPrompt("generate_unittest", map[string]any{
		"Prompt":         prompt,
		"Context":        newCode,
		"TestFileExists": testFileExists,
	})
	if err != nil {
		return "", err
	}
	codeBlock, err := a.generateCode(ctx, genPrompt)
	if err != nil {
		return "", err
	}
	fmt.Printf("Unit Test Code block:\n%s\n", codeBlock)

	content := codeBlock
	if testFileExists {
		// Separate the new code from whatever the existing file ends with.
		content = "\n" + codeBlock
	}

	// O_CREATE handles the new-file case and O_APPEND the existing-file case,
	// so one open call serves both.
	fp, err := os.OpenFile(testFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return "", err
	}
	if _, err := fp.WriteString(content); err != nil {
		_ = fp.Close() // the write error is the one worth reporting
		return "", err
	}
	// Close errors are reported too: a failed close can mean the data did not
	// fully reach the file.
	if err := fp.Close(); err != nil {
		return "", err
	}
	return testFile, nil
}
// generateCode sends prompt to the LLM's code endpoint and returns the code
// found between the first and the last ``` fence of the response.
//
// The opening fence's info string (e.g. "go" or "golang") and the newline
// after it are removed, so the returned code starts at its first real line
// rather than with a stray newline or leftover tag that would otherwise be
// diffed into the patched file. The LAST fence is used as the end so code that
// itself contains ``` (for example in a raw string literal) is not truncated.
func (a *agent) generateCode(ctx context.Context, prompt string) (string, error) {
	rsp, err := a.llm.CodePrompt(ctx, prompt)
	if err != nil {
		return "", err
	}
	const fence = "```"
	startIdx := strings.Index(rsp, fence)
	endIdx := strings.LastIndex(rsp, fence)
	if startIdx == -1 || endIdx == -1 || startIdx >= endIdx {
		return "", fmt.Errorf("unable to find code block in response: %s", rsp)
	}
	codeBlock := rsp[startIdx+len(fence) : endIdx]

	// An info string is a single word (possibly empty) on the fence line. When
	// the fence line holds something else, fall back to stripping a "go" tag.
	if info, rest, found := strings.Cut(codeBlock, "\n"); found && !strings.ContainsAny(strings.TrimSpace(info), " \t") {
		codeBlock = rest
	} else {
		codeBlock = strings.TrimPrefix(codeBlock, "go")
	}
	return codeBlock, nil
}
// patchFile overwrites fileName with the text described by diffs. Equal and
// inserted segments are written in order; deleted segments are left out.
func patchFile(fileName string, diffs []diffmatchpatch.Diff) error {
	var out bytes.Buffer
	for i := range diffs {
		switch d := diffs[i]; d.Type {
		case diffmatchpatch.DiffEqual, diffmatchpatch.DiffInsert:
			// bytes.Buffer.WriteString always returns a nil error.
			out.WriteString(d.Text)
		}
	}
	return os.WriteFile(fileName, out.Bytes(), 0644)
}
// cleanDiffs ignores any deletions at the beginning or end of the diff, since
// the LLM may trim that text off its output (for example when it was shown
// only part of the file as context) rather than intend to remove it.
//
// "Ignoring" a deletion means keeping the original text. patchFile drops
// deleted segments, so leading and trailing deletions are turned into
// equalities here to make patchFile write that text back unchanged. A run of
// deletions adjacent to an insertion is a replacement rather than a trim and
// is left as a deletion. If every segment is a deletion (the LLM returned no
// code), the whole original text is kept.
//
// The input slice is not modified; a new slice is returned.
func cleanDiffs(diffs []diffmatchpatch.Diff) []diffmatchpatch.Diff {
	n := len(diffs)
	cleaned := make([]diffmatchpatch.Diff, n)
	copy(cleaned, diffs)

	// Leading run of deletions: cleaned[0:lead].
	lead := 0
	for lead < n && cleaned[lead].Type == diffmatchpatch.DiffDelete {
		lead++
	}
	if lead == n || cleaned[lead].Type != diffmatchpatch.DiffInsert {
		for i := 0; i < lead; i++ {
			cleaned[i].Type = diffmatchpatch.DiffEqual
		}
	}

	// Trailing run of deletions: cleaned[trail:n]. It never overlaps the
	// leading run because cleaned[lead] (when it exists) is not a deletion.
	trail := n
	for trail > lead && cleaned[trail-1].Type == diffmatchpatch.DiffDelete {
		trail--
	}
	if trail < n && cleaned[trail-1].Type != diffmatchpatch.DiffInsert {
		for i := trail; i < n; i++ {
			cleaned[i].Type = diffmatchpatch.DiffEqual
		}
	}
	return cleaned
}