Checks to make sure repo is indexed before generating code. Don't generate tests for changes to tests. Remove unused code. Fix bootstrapping issue with langchaingo tables.
314 lines
7.0 KiB
Go
314 lines
7.0 KiB
Go
package autopatch
|
|
|
|
import (
|
|
"ai-code-assistant/pkg/config"
|
|
"ai-code-assistant/pkg/database"
|
|
"ai-code-assistant/pkg/indexer"
|
|
"ai-code-assistant/pkg/llm"
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-git/go-billy/v5/osfs"
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing/cache"
|
|
"github.com/go-git/go-git/v5/storage/filesystem"
|
|
"github.com/sergi/go-diff/diffmatchpatch"
|
|
"github.com/urfave/cli/v3"
|
|
"log/slog"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
)
|
|
|
|
func Command() *cli.Command {
|
|
return &cli.Command{
|
|
Name: "auto-patch",
|
|
Usage: "this command accepts a repository and a prompt and will generate a git commit attempting code modifications to satisfy the prompt.",
|
|
Action: (&agent{}).run,
|
|
Flags: []cli.Flag{
|
|
&cli.StringFlag{
|
|
Name: "repo",
|
|
Usage: "path to git repository",
|
|
Required: true,
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "task",
|
|
Usage: "task to perform, e.g. \"add a test for a function\"",
|
|
Required: true,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
type agent struct {
|
|
llm *llm.LLM
|
|
}
|
|
|
|
func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
|
|
llmRef := llm.FromContext(ctx)
|
|
a.llm = llmRef
|
|
|
|
// Make sure we're indexed.
|
|
idx := indexer.New(ctx, cmd.String("repo"), config.FromContext(ctx).IndexChunkSize, false)
|
|
if err := idx.Index(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Attempt to generate the commit.
|
|
err := a.generateGitCommit(ctx, cmd.String("repo"), cmd.String("task"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) generateGitCommit(ctx context.Context, repoPath, prompt string) error {
|
|
var affectedFiles []string
|
|
|
|
fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affectedFiles = append(affectedFiles, fileName)
|
|
|
|
// If we modified a test, we don't need to generate a test.
|
|
if !strings.HasSuffix(fileName, "_test.go") {
|
|
testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affectedFiles = append(affectedFiles, testFile)
|
|
}
|
|
|
|
if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Info("committed changes to git repo", "repo", repoPath)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) commit(ctx context.Context, prompt, repoPath string, files ...string) error {
|
|
gitPath := osfs.New(filepath.Join(repoPath, ".git"))
|
|
|
|
gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), osfs.New(repoPath))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
workTree, err := gitRepo.Worktree()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, file := range files {
|
|
// Relative paths.
|
|
file = strings.TrimPrefix(file, repoPath)
|
|
file = strings.TrimPrefix(file, "/")
|
|
|
|
if _, err := workTree.Add(file); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_commitmsg", map[string]any{
|
|
"Prompt": prompt,
|
|
"Files": files,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rsp, err := a.llm.ChatPrompt(ctx, genPrompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := workTree.Commit(rsp, &git.CommitOptions{}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) generateCodePatch(ctx context.Context, repoPath, prompt string) (string, string, error) {
|
|
db := database.FromContext(ctx)
|
|
cfg := config.FromContext(ctx)
|
|
|
|
repoID, _, err := db.UpsertRepo(ctx, repoPath)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
relDocs := llm.NewGetRelevantDocs(db, a.llm, repoID, cfg.RelevantDocs)
|
|
chunks, err := relDocs.GetRelevantFileChunks(ctx, prompt)
|
|
if err != nil {
|
|
return "", "", err
|
|
} else if len(chunks) == 0 {
|
|
return "", "", errors.New("no relevant chunks found")
|
|
}
|
|
|
|
chunk := chunks[0]
|
|
|
|
slog.Info("found most relevant file chunk", "file", chunk.Name, "start", chunk.Start, "end", chunk.End, "score", chunk.Score, "id", chunk.ChunkID)
|
|
|
|
chunkContext, err := db.GetChunkContext(ctx, chunks[0].ChunkID, 1, chunks[0].Name, repoID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_patch", map[string]string{
|
|
"Prompt": prompt,
|
|
"Context": chunkContext,
|
|
})
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
codeBlock, err := a.generateCode(ctx, genPrompt)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
fmt.Printf("Code block:\n%s\n", codeBlock)
|
|
|
|
fileName := chunks[0].Name
|
|
originalFile, err := os.ReadFile(fileName)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
dmp := diffmatchpatch.New()
|
|
diffs := dmp.DiffMain(string(originalFile), codeBlock, true)
|
|
diffs = cleanDiffs(diffs)
|
|
|
|
fmt.Printf("File to patch: %s\n", fileName)
|
|
fmt.Println(dmp.DiffPrettyText(diffs))
|
|
|
|
if err := patchFile(fileName, diffs); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return fileName, codeBlock, err
|
|
}
|
|
|
|
func (a *agent) generateUnitTest(ctx context.Context, prompt, fileName, newCode string) (string, error) {
|
|
// Check to see if a test file for this already exists.
|
|
testFileExists := false
|
|
|
|
testFile := strings.ReplaceAll(fileName, ".go", "_test.go")
|
|
if _, err := os.Stat(testFile); err == nil {
|
|
testFileExists = true
|
|
}
|
|
|
|
genPrompt, err := llm.GetPrompt("generate_unittest", map[string]any{
|
|
"Prompt": prompt,
|
|
"Context": newCode,
|
|
"TestFileExists": testFileExists,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
codeBlock, err := a.generateCode(ctx, genPrompt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
fmt.Printf("Unit Test Code block:\n%s\n", codeBlock)
|
|
|
|
if testFileExists {
|
|
fp, err := os.OpenFile(testFile, os.O_APPEND|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer fp.Close()
|
|
|
|
if _, err := fp.WriteString("\n" + codeBlock); err != nil {
|
|
return "", err
|
|
}
|
|
} else {
|
|
fp, err := os.Open(testFile)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer fp.Close()
|
|
|
|
if _, err := fp.WriteString(codeBlock); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
return testFile, nil
|
|
}
|
|
|
|
func (a *agent) generateCode(ctx context.Context, prompt string) (string, error) {
|
|
rsp, err := a.llm.CodePrompt(ctx, prompt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
startIdx := strings.Index(rsp, "```")
|
|
endIdx := strings.LastIndex(rsp, "```")
|
|
|
|
if startIdx == -1 || endIdx == -1 || startIdx >= endIdx {
|
|
return "", fmt.Errorf("unable to find code block in response: %s", rsp)
|
|
}
|
|
|
|
codeBlock := rsp[startIdx+3 : endIdx]
|
|
if strings.HasPrefix(codeBlock, "go") {
|
|
codeBlock = codeBlock[2:]
|
|
}
|
|
|
|
return codeBlock, nil
|
|
}
|
|
|
|
func patchFile(fileName string, diffs []diffmatchpatch.Diff) error {
|
|
var buff bytes.Buffer
|
|
for _, diff := range diffs {
|
|
text := diff.Text
|
|
|
|
switch diff.Type {
|
|
case diffmatchpatch.DiffInsert:
|
|
_, _ = buff.WriteString(text)
|
|
case diffmatchpatch.DiffEqual:
|
|
_, _ = buff.WriteString(text)
|
|
}
|
|
}
|
|
|
|
if err := os.WriteFile(fileName, buff.Bytes(), 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cleanDiffs will ignore any deletions at the beginning or end of the context since the LLM may trim these off.
|
|
func cleanDiffs(diffs []diffmatchpatch.Diff) []diffmatchpatch.Diff {
|
|
startIdx := 0
|
|
endIdx := len(diffs)
|
|
|
|
for idx, diff := range diffs {
|
|
if diff.Type == diffmatchpatch.DiffDelete {
|
|
startIdx = idx + 1
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
for idx := len(diffs) - 1; idx >= 0; idx-- {
|
|
if diffs[idx].Type == diffmatchpatch.DiffDelete {
|
|
endIdx = idx - 1
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
return diffs[startIdx:endIdx]
|
|
}
|