uni/WEB43-diary/internal/database/database.go

203 lines
5.1 KiB
Go

package database
import (
"database/sql"
"fmt"
"log"
"reflect"
"sort"
"strings"
"time"
_ "github.com/go-sql-driver/mysql" // Import for registration
"gitea.zokki.net/zokki/uni/web43-diary/context"
"gitea.zokki.net/zokki/uni/web43-diary/internal/config"
)
type (
	// SQLTable is implemented by row types handled by the generic read
	// helpers in this package (GetOne, GetOneWhere, GetAll, GetAllWhere).
	// LoadForeignValues is invoked on every value after it has been
	// deserialized from a row and after the result set has been closed.
	SQLTable interface {
		LoadForeignValues(ctx *context.Context) error
	}

	// SQLPrimary is implemented by types that expose a list of key names,
	// presumably their primary-key column names (consumers are not visible
	// in this file — confirm).
	SQLPrimary interface {
		GetPrimaryKeys() []string
	}
)

// tablesToCreate holds the table definitions registered through
// AddCreateTableQueue until CreateTablesFromQueue issues the CREATE
// statements for them.
var tablesToCreate = []*sqlTable{}
// NewDB opens a MySQL connection pool using config.Config.DatabaseDsn and
// checks that the server is reachable with Ping.
//
// The pool is limited to one open and one idle connection. Each connection
// is recycled after three minutes of lifetime or idleness.
//
// Any error from sql.Open or Ping is returned unchanged. If Ping fails, the
// pool is closed before returning so it does not outlive the failed call.
func NewDB() (*sql.DB, error) {
	db, err := sql.Open("mysql", config.Config.DatabaseDsn)
	if err != nil {
		return nil, err
	}

	if err := db.Ping(); err != nil {
		// sql.Open starts background work for the pool. The caller never
		// receives db, so release it here. Its Close error is irrelevant
		// next to the Ping failure being reported.
		_ = db.Close()
		return nil, err
	}

	db.SetConnMaxLifetime(3 * time.Minute)
	db.SetConnMaxIdleTime(3 * time.Minute)
	// With a single connection, any open *sql.Rows blocks further queries.
	// That is why the readers in this package close their rows before
	// touching the database again.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	return db, nil
}
func AddCreateTableQueue(table any) {
tablesToCreate = append(tablesToCreate, getSQLTableFromInterface(table))
}
// CreateTablesFromQueue opens a fresh database connection and issues one
// "CREATE TABLE IF NOT EXISTS" statement per table registered through
// AddCreateTableQueue.
//
// Each statement lists the column definitions, then a FOREIGN KEY clause for
// every column that declares one, then the table's PrimaryKey text. The
// trailing ", " separator is trimmed.
//
// Any connection or statement error terminates the process via log.Fatal.
// Deferred calls, including db.Close, do not run in that case.
//
// NOTE(review): tables are created in ByTableName order. MySQL normally
// requires a referenced table to exist before a FOREIGN KEY to it is
// created, so this presumably relies on table naming — confirm.
func CreateTablesFromQueue() {
	db, err := NewDB()
	if err != nil {
		log.Fatal("[CreateTable] connect db: ", err)
	}
	defer db.Close()

	sort.Sort(ByTableName(tablesToCreate))

	for _, table := range tablesToCreate {
		var definitions, foreignKeys strings.Builder
		for _, column := range table.Columns {
			fmt.Fprintf(&definitions, "%s %s, ", column.Name, column.Statement)
			if column.ForeignKey == "" {
				continue
			}
			fmt.Fprintf(&foreignKeys, "FOREIGN KEY (%s) REFERENCES %s, ", column.Name, column.ForeignKey)
		}

		body := strings.TrimSuffix(definitions.String()+foreignKeys.String()+table.PrimaryKey, ", ")
		createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (", table.Name) + body + ")"

		if _, err := db.Exec(createStatement); err != nil {
			log.Fatalf("[CreateTable] create table `%s`: %v", table.Name, err)
		}
	}
}
// GetOne fetches a single row, selected by the query builder that
// QueryBuilderFromInterface derives from val. It is shorthand for
// GetOneWhere with that builder.
func GetOne[TableVal SQLTable](ctx *context.Context, val TableVal) (TableVal, error) {
	builder := QueryBuilderFromInterface(val)
	return GetOneWhere(ctx, builder)
}
// GetOneWhere runs the SELECT built by queryBuilder, limited to one row. It
// deserializes that row into queryBuilder.object and then calls
// LoadForeignValues on it. The value returned is queryBuilder.object itself,
// populated in place (TableVal is presumably a pointer type, as GetAllWhere
// requires).
//
// On any failure the zero TableVal is returned together with:
//   - the error from QueryContext, Rows.Err, Columns, Scan or
//     LoadForeignValues, or
//   - "no data found for table '<name>'" when the query succeeded but
//     matched no row.
func GetOneWhere[TableVal SQLTable](ctx *context.Context, queryBuilder *queryBuilder[TableVal]) (TableVal, error) {
	var empty TableVal

	sqlQuery, sqlValues := queryBuilder.Limit(1).BuildSelect()
	sqlRows, err := ctx.DB.QueryContext(ctx, sqlQuery, sqlValues...)
	if err != nil {
		return empty, err
	}
	defer sqlRows.Close()

	if !sqlRows.Next() {
		// Next returns false both for an empty result and for a failed
		// iteration. Only the former means "no data".
		if err := sqlRows.Err(); err != nil {
			return empty, err
		}
		return empty, fmt.Errorf("no data found for table '%s'", queryBuilder.table.Name)
	}

	columns, err := sqlRows.Columns()
	if err != nil {
		return empty, err
	}

	// Scan every column generically: values receives the data, and
	// columnPointers holds &values[i] as Scan destinations.
	values := make([]any, len(columns))
	columnPointers := make([]any, len(columns))
	for i := range values {
		columnPointers[i] = &values[i]
	}
	if err := sqlRows.Scan(columnPointers...); err != nil {
		return empty, err
	}

	val := queryBuilder.object
	deserialize(val, columns, values)

	sqlRows.Close() // needs to be closed before using db again
	if err := val.LoadForeignValues(ctx); err != nil {
		return empty, err
	}
	return val, nil
}
// GetAll fetches every row selected by the query builder that
// QueryBuilderFromInterface derives from val. It is shorthand for
// GetAllWhere with that builder.
func GetAll[TableVal SQLTable](ctx *context.Context, val TableVal) ([]TableVal, error) {
	builder := QueryBuilderFromInterface(val)
	return GetAllWhere(ctx, builder)
}
// GetAllWhere runs the SELECT built by queryBuilder and returns one newly
// allocated TableVal per result row. LoadForeignValues is applied to each
// row after the result set has been closed.
//
// TableVal must be a pointer type. Each row value is created with reflect.New
// on the element type of queryBuilder.object and asserted back to TableVal;
// both steps panic for non-pointer types.
//
// An empty result yields an empty, non-nil slice. Any query, scan, iteration
// or LoadForeignValues error aborts the call and returns a nil slice.
func GetAllWhere[TableVal SQLTable](ctx *context.Context, queryBuilder *queryBuilder[TableVal]) ([]TableVal, error) {
	sqlQuery, sqlValues := queryBuilder.BuildSelect()
	sqlRows, err := ctx.DB.QueryContext(ctx, sqlQuery, sqlValues...)
	if err != nil {
		log.Println("err", err, sqlQuery, sqlValues)
		return nil, err
	}
	defer sqlRows.Close()

	columns, err := sqlRows.Columns()
	if err != nil {
		return nil, err
	}

	// Scan every column generically into values through &values[i]. The
	// same buffers are reused for every row.
	values := make([]any, len(columns))
	columnPointers := make([]any, len(columns))
	for i := range values {
		columnPointers[i] = &values[i]
	}

	tableRows := []TableVal{}
	reflectType := reflect.TypeOf(queryBuilder.object).Elem()
	for sqlRows.Next() {
		if err := sqlRows.Scan(columnPointers...); err != nil {
			return nil, err
		}
		tableVal := reflect.New(reflectType).Interface().(TableVal)
		deserialize(tableVal, columns, values)
		tableRows = append(tableRows, tableVal)
	}
	// Next also returns false when iteration fails part-way. Without this
	// check, a truncated result would be returned as if it were complete.
	if err := sqlRows.Err(); err != nil {
		return nil, err
	}

	sqlRows.Close() // needs to be closed before using db again
	for _, row := range tableRows {
		if err := row.LoadForeignValues(ctx); err != nil {
			return nil, err
		}
	}
	return tableRows, nil
}
// InsertInto inserts table as a new row into the SQL table named by
// getSQLTableName. Column names and values come from serialize. It returns
// the driver's LastInsertId for the new row.
//
// When serialize yields no columns, nothing is executed and (0, nil) is
// returned.
func InsertInto(ctx *context.Context, table any) (int64, error) {
	tableName := getSQLTableName(table)
	columnNames, args := serialize(table)
	if len(columnNames) == 0 {
		return 0, nil
	}

	// One "?" placeholder per value: n-1 "?, " pieces here plus the final
	// " ?" in the format string.
	leadingPlaceholders := strings.Repeat("?, ", len(args)-1)
	insertQuery := fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s ?)", tableName, strings.Join(columnNames, ", "), leadingPlaceholders)

	result, err := ctx.DB.ExecContext(ctx, insertQuery, args...)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// Update writes the columns serialized from newTable to the rows selected by
// searchTable. The WHERE clause and its arguments come from the query builder
// derived from searchTable. The SET list is "col = ?" for every column
// produced by serialize(newTable). If newTable serializes to no columns,
// nothing is executed and nil is returned.
//
// NOTE(review): if searchTable produces an empty WHERE clause, this would
// presumably update every row of the table — confirm how where.Build
// behaves for an empty search value.
func Update[TableVal SQLTable](ctx *context.Context, searchTable TableVal, newTable TableVal) error {
	setColumns, setArgs := serialize(newTable)
	if len(setColumns) == 0 {
		return nil
	}

	search := QueryBuilderFromInterface(searchTable)
	whereClause, whereArgs := search.where.Build(true)

	// SET placeholders come first in the statement, so their arguments
	// must precede the WHERE arguments.
	args := make([]any, 0, len(setArgs)+len(whereArgs))
	args = append(args, setArgs...)
	args = append(args, whereArgs...)

	updateQuery := fmt.Sprintf("UPDATE `%s` SET %s = ? %s", search.table.Name, strings.Join(setColumns, " = ?, "), whereClause)
	_, err := ctx.DB.ExecContext(ctx, updateQuery, args...)
	return err
}
// Delete executes the DELETE statement that the query builder derived from
// table produces, and returns the execution error, if any.
func Delete(ctx *context.Context, table SQLTable) error {
	deleteQuery, args := QueryBuilderFromInterface(table).BuildDelete()
	if _, err := ctx.DB.ExecContext(ctx, deleteQuery, args...); err != nil {
		return err
	}
	return nil
}