Files
ai-images/Source/AIImages/Services/StableDiffusionApiService.cs

286 lines
10 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using AIImages.Models;
using Newtonsoft.Json;
using Verse;
namespace AIImages.Services
{
    /// <summary>
    /// Service for talking to the Stable Diffusion HTTP API exposed by AUTOMATIC1111 WebUI
    /// (the <c>/sdapi/v1/*</c> endpoints). Generated images are written as PNG files under
    /// RimWorld's save-data folder.
    /// </summary>
    public class StableDiffusionApiService : IStableDiffusionApiService
    {
        // Number of prompt characters echoed into the log when a generation starts.
        private const int PromptPreviewLength = 50;

        // Shared by every request, including the long-running txt2img call.
        private static readonly TimeSpan RequestTimeout = TimeSpan.FromMinutes(5);

        private readonly HttpClient httpClient;
        private readonly string saveFolderPath;

        /// <summary>
        /// Creates the service and makes sure the output folder exists.
        /// </summary>
        /// <param name="savePath">
        /// Output folder, relative to <c>GenFilePaths.SaveDataFolderPath</c>.
        /// </param>
        public StableDiffusionApiService(string savePath = "AIImages/Generated")
        {
            httpClient = new HttpClient { Timeout = RequestTimeout };

            saveFolderPath = Path.Combine(GenFilePaths.SaveDataFolderPath, savePath);

            // CreateDirectory is a no-op when the folder already exists.
            Directory.CreateDirectory(saveFolderPath);
        }

        /// <summary>
        /// Sends a txt2img request, decodes the first returned image and saves it as a PNG
        /// in the output folder.
        /// </summary>
        /// <param name="request">Generation parameters; must be non-null with a non-null prompt.</param>
        /// <returns>
        /// A success result carrying the image bytes and saved path, or a failure result.
        /// All errors (bad input, HTTP errors, timeouts, decode/IO errors) are reported as
        /// failure results rather than thrown.
        /// </returns>
        public async Task<GenerationResult> GenerateImageAsync(GenerationRequest request)
        {
            if (request == null)
            {
                Log.Warning("[AI Images] Generation request is null");
                return GenerationResult.Failure("Generation request is null");
            }

            if (request.Prompt == null)
            {
                Log.Warning("[AI Images] Generation request has no prompt");
                return GenerationResult.Failure("Prompt is not set");
            }

            try
            {
                Log.Message(
                    $"[AI Images] Starting image generation with prompt: {Truncate(request.Prompt, PromptPreviewLength)}..."
                );

                // Request body for the AUTOMATIC1111 txt2img endpoint. The image is requested
                // inline in the response (send_images) and saved locally by this service
                // rather than on the server (save_images).
                var apiRequest = new
                {
                    prompt = request.Prompt,
                    negative_prompt = request.NegativePrompt,
                    steps = request.Steps,
                    cfg_scale = request.CfgScale,
                    width = request.Width,
                    height = request.Height,
                    sampler_name = request.Sampler,
                    scheduler = request.Scheduler,
                    seed = request.Seed,
                    save_images = false,
                    send_images = true,
                };

                string jsonRequest = JsonConvert.SerializeObject(apiRequest);

                // NOTE(review): request.Model is used as the server base URL here, while the
                // other methods receive an explicit apiEndpoint, and no checkpoint override is
                // sent in the body. Confirm callers really store the endpoint URL in Model.
                string endpoint = BuildUrl(request.Model, "sdapi/v1/txt2img");

                string jsonResponse;
                using (var content = new StringContent(jsonRequest, Encoding.UTF8, "application/json"))
                using (HttpResponseMessage response = await httpClient.PostAsync(endpoint, content))
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        string errorContent = await response.Content.ReadAsStringAsync();
                        Log.Error(
                            $"[AI Images] API request failed: {response.StatusCode} - {errorContent}"
                        );
                        return GenerationResult.Failure($"API Error: {response.StatusCode}");
                    }

                    jsonResponse = await response.Content.ReadAsStringAsync();
                }

                var apiResponse = JsonConvert.DeserializeObject<Txt2ImgResponse>(jsonResponse);
                if (
                    apiResponse?.images == null
                    || apiResponse.images.Length == 0
                    || string.IsNullOrEmpty(apiResponse.images[0])
                )
                {
                    return GenerationResult.Failure("No images returned from API");
                }

                // Images come back as base64 strings; only the first one is used.
                byte[] imageData = Convert.FromBase64String(apiResponse.images[0]);
                string fullPath = await SaveImageAsync(imageData);

                Log.Message($"[AI Images] Image generated successfully and saved to: {fullPath}");
                return GenerationResult.SuccessResult(imageData, fullPath, request);
            }
            catch (TaskCanceledException)
            {
                // No cancellation token is passed anywhere above, so this is the
                // HttpClient timeout firing.
                Log.Warning("[AI Images] Image generation request timed out");
                return GenerationResult.Failure("Request timeout. Generation took too long.");
            }
            catch (HttpRequestException ex)
            {
                Log.Error($"[AI Images] HTTP error: {ex.Message}");
                return GenerationResult.Failure($"Connection error: {ex.Message}");
            }
            catch (Exception ex)
            {
                // ToString() includes the exception type, inner exceptions and stack trace.
                Log.Error($"[AI Images] Unexpected error: {ex}");
                return GenerationResult.Failure($"Error: {ex.Message}");
            }
        }

        /// <summary>
        /// Probes the server by requesting its model list.
        /// </summary>
        /// <param name="apiEndpoint">Server base URL; a trailing '/' is tolerated.</param>
        /// <returns><c>true</c> when the request succeeds with a 2xx status; otherwise <c>false</c>.</returns>
        public async Task<bool> CheckApiAvailability(string apiEndpoint)
        {
            try
            {
                string endpoint = BuildUrl(apiEndpoint, "sdapi/v1/sd-models");
                using (HttpResponseMessage response = await httpClient.GetAsync(endpoint))
                {
                    return response.IsSuccessStatusCode;
                }
            }
            catch (Exception ex)
            {
                Log.Warning($"[AI Images] API check failed: {ex.Message}");
                return false;
            }
        }

        /// <summary>
        /// Lists the checkpoints known to the server, preferring each model's title and
        /// falling back to its model name.
        /// </summary>
        /// <param name="apiEndpoint">Server base URL; a trailing '/' is tolerated.</param>
        /// <returns>The model names, or an empty list on any failure.</returns>
        public async Task<List<string>> GetAvailableModels(string apiEndpoint)
        {
            try
            {
                List<string> modelNames = await FetchNamesAsync<SdModel>(
                    BuildUrl(apiEndpoint, "sdapi/v1/sd-models"),
                    model => model.title ?? model.model_name
                );

                if (modelNames == null)
                {
                    return new List<string>();
                }

                Log.Message($"[AI Images] Found {modelNames.Count} models");
                return modelNames;
            }
            catch (Exception ex)
            {
                Log.Error($"[AI Images] Failed to load models: {ex.Message}");
                return new List<string>();
            }
        }

        /// <summary>
        /// Lists the sampler names reported by the server.
        /// </summary>
        /// <param name="apiEndpoint">Server base URL; a trailing '/' is tolerated.</param>
        /// <returns>
        /// The server's samplers, or a built-in default list when the request fails or
        /// returns a non-success status.
        /// </returns>
        public async Task<List<string>> GetAvailableSamplers(string apiEndpoint)
        {
            try
            {
                List<string> samplerNames = await FetchNamesAsync<SdSampler>(
                    BuildUrl(apiEndpoint, "sdapi/v1/samplers"),
                    sampler => sampler.name
                );

                if (samplerNames == null)
                {
                    return GetDefaultSamplers();
                }

                Log.Message($"[AI Images] Found {samplerNames.Count} samplers");
                return samplerNames;
            }
            catch (Exception ex)
            {
                Log.Warning($"[AI Images] Failed to load samplers: {ex.Message}");
                return GetDefaultSamplers();
            }
        }

        /// <summary>
        /// Lists the scheduler names reported by the server.
        /// </summary>
        /// <param name="apiEndpoint">Server base URL; a trailing '/' is tolerated.</param>
        /// <returns>
        /// The server's schedulers, or a built-in default list when the request fails or
        /// returns a non-success status.
        /// </returns>
        public async Task<List<string>> GetAvailableSchedulers(string apiEndpoint)
        {
            try
            {
                List<string> schedulerNames = await FetchNamesAsync<SdScheduler>(
                    BuildUrl(apiEndpoint, "sdapi/v1/schedulers"),
                    scheduler => scheduler.name
                );

                if (schedulerNames == null)
                {
                    return GetDefaultSchedulers();
                }

                Log.Message($"[AI Images] Found {schedulerNames.Count} schedulers");
                return schedulerNames;
            }
            catch (Exception ex)
            {
                Log.Warning($"[AI Images] Failed to load schedulers: {ex.Message}");
                return GetDefaultSchedulers();
            }
        }

        /// <summary>
        /// GETs a JSON array from <paramref name="url"/> and projects each element to a name.
        /// Null elements and null/empty names are skipped. Transport and JSON exceptions
        /// propagate to the caller.
        /// </summary>
        /// <returns>
        /// The collected names (possibly empty), or <c>null</c> when the server answered
        /// with a non-success status code.
        /// </returns>
        private async Task<List<string>> FetchNamesAsync<TItem>(
            string url,
            Func<TItem, string> selectName
        )
            where TItem : class
        {
            using (HttpResponseMessage response = await httpClient.GetAsync(url))
            {
                if (!response.IsSuccessStatusCode)
                {
                    return null;
                }

                string json = await response.Content.ReadAsStringAsync();
                List<TItem> items = JsonConvert.DeserializeObject<List<TItem>>(json);

                var names = new List<string>();
                if (items == null)
                {
                    return names;
                }

                foreach (TItem item in items)
                {
                    string name = item == null ? null : selectName(item);
                    if (!string.IsNullOrEmpty(name))
                    {
                        names.Add(name);
                    }
                }

                return names;
            }
        }

        /// <summary>
        /// Writes the image bytes to a new timestamped PNG in the output folder.
        /// </summary>
        /// <returns>The full path of the written file.</returns>
        private async Task<string> SaveImageAsync(byte[] imageData)
        {
            // Re-create the folder in case it was removed after construction (no-op otherwise).
            Directory.CreateDirectory(saveFolderPath);

            // Millisecond precision so consecutive generations don't overwrite each other.
            string fileName = $"pawn_{DateTime.Now:yyyyMMdd_HHmmss_fff}.png";
            string fullPath = Path.Combine(saveFolderPath, fileName);

            // FileStream with useAsync works on every target framework, unlike
            // File.WriteAllBytesAsync, which is missing from .NET Framework.
            using (
                var stream = new FileStream(
                    fullPath,
                    FileMode.Create,
                    FileAccess.Write,
                    FileShare.None,
                    4096,
                    true
                )
            )
            {
                await stream.WriteAsync(imageData, 0, imageData.Length);
            }

            return fullPath;
        }

        /// <summary>
        /// Joins a server base URL and an API path. A trailing '/' on the base
        /// (e.g. "http://127.0.0.1:7860/") is trimmed so the result never contains "//".
        /// A null base is treated as empty.
        /// </summary>
        private static string BuildUrl(string baseUrl, string path)
        {
            return $"{(baseUrl ?? string.Empty).TrimEnd('/')}/{path}";
        }

        /// <summary>
        /// Returns at most <paramref name="maxLength"/> characters of <paramref name="text"/>
        /// (empty for null), never splitting a UTF-16 surrogate pair.
        /// </summary>
        private static string Truncate(string text, int maxLength)
        {
            if (text == null)
            {
                return string.Empty;
            }

            if (text.Length <= maxLength)
            {
                return text;
            }

            int length = Math.Max(0, maxLength);
            if (length > 0 && char.IsHighSurrogate(text[length - 1]))
            {
                length--;
            }

            return text.Substring(0, length);
        }

        // Fallback sampler names used when the server's /samplers endpoint is unavailable.
        private static List<string> GetDefaultSamplers()
        {
            return new List<string>
            {
                "Euler a",
                "Euler",
                "LMS",
                "Heun",
                "DPM2",
                "DPM2 a",
                "DPM++ 2S a",
                "DPM++ 2M",
                "DPM++ SDE",
                "DPM fast",
                "DPM adaptive",
                "LMS Karras",
                "DPM2 Karras",
                "DPM2 a Karras",
                "DPM++ 2S a Karras",
                "DPM++ 2M Karras",
                "DPM++ SDE Karras",
                "DDIM",
                "PLMS",
            };
        }

        // Fallback scheduler names used when the server's /schedulers endpoint is unavailable.
        private static List<string> GetDefaultSchedulers()
        {
            return new List<string>
            {
                "Automatic",
                "Uniform",
                "Karras",
                "Exponential",
                "Polyexponential",
                "SGM Uniform",
            };
        }

        // DTOs for deserializing JSON responses; property names mirror the API's JSON keys.
#pragma warning disable S3459, S1144 // Properties set by JSON deserializer
        private sealed class Txt2ImgResponse
        {
            public string[] images { get; set; }
        }

        private sealed class SdModel
        {
            public string title { get; set; }
            public string model_name { get; set; }
        }

        private sealed class SdSampler
        {
            public string name { get; set; }
        }

        private sealed class SdScheduler
        {
            public string name { get; set; }
        }
#pragma warning restore S3459, S1144
    }
}