Files
ai-images/Source/AIImages/Services/StableDiffusionNetAdapter.cs

309 lines
11 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using AIImages.Models;
using StableDiffusionNet;
using StableDiffusionNet.Exceptions;
using StableDiffusionNet.Interfaces;
using StableDiffusionNet.Models.Requests;
using Verse;
namespace AIImages.Services
{
/// <summary>
/// Адаптер для Stable Diffusion API (AUTOMATIC1111 WebUI)
/// Использует библиотеку StableDiffusionNet для работы с API
/// </summary>
/// <summary>
/// Adapter for the Stable Diffusion API (AUTOMATIC1111 WebUI).
/// Delegates all HTTP work to the StableDiffusionNet client library and saves
/// generated images as PNG files under the RimWorld save-data folder.
/// </summary>
public class StableDiffusionNetAdapter : IStableDiffusionApiService, IDisposable
{
    // Maximum number of prompt characters echoed into the log.
    private const int PromptLogLength = 50;

    private readonly IStableDiffusionClient _client;
    private readonly string _saveFolderPath;
    private bool _disposed;

    /// <summary>
    /// Creates the adapter, ensures the output folder exists and builds the API client
    /// (300 s timeout, 3 retries with a 1000 ms delay).
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the WebUI API.</param>
    /// <param name="savePath">
    /// Output folder, combined with <c>GenFilePaths.SaveDataFolderPath</c> via
    /// <see cref="Path.Combine(string, string)"/> (so an absolute path is used as-is).
    /// </param>
    /// <exception cref="ArgumentException">
    /// <paramref name="apiEndpoint"/> is null, empty or whitespace.
    /// </exception>
    public StableDiffusionNetAdapter(string apiEndpoint, string savePath = "AIImages/Generated")
    {
        if (string.IsNullOrWhiteSpace(apiEndpoint))
        {
            throw new ArgumentException(
                "API endpoint cannot be null or empty",
                nameof(apiEndpoint)
            );
        }

        // Resolve the folder where generated images are stored.
        _saveFolderPath = Path.Combine(GenFilePaths.SaveDataFolderPath, savePath);

        // CreateDirectory is a no-op when the folder already exists.
        Directory.CreateDirectory(_saveFolderPath);

        // Build the StableDiffusion client via its builder.
        _client = new StableDiffusionClientBuilder()
            .WithBaseUrl(apiEndpoint)
            .WithTimeout(300) // 5 minutes, in seconds
            .WithRetry(retryCount: 3, retryDelayMilliseconds: 1000)
            .Build();

        Log.Message(
            $"[AI Images] StableDiffusion adapter initialized with endpoint: {apiEndpoint}"
        );
    }

    /// <summary>
    /// Runs a txt2img generation, decodes the first returned image from base64
    /// and writes it to a new PNG file in the save folder.
    /// API, network, decoding and IO failures are logged and reported through
    /// <c>GenerationResult.Failure</c> rather than thrown.
    /// </summary>
    /// <param name="request">Generation parameters; a null request yields a failure result.</param>
    /// <param name="cancellationToken">Token that cancels the API call.</param>
    /// <returns>A success result with the image bytes and file path, or a failure result.</returns>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public async Task<GenerationResult> GenerateImageAsync(
        GenerationRequest request,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        if (request == null)
        {
            return GenerationResult.Failure("Request cannot be null");
        }

        try
        {
            Log.Message(
                $"[AI Images] Starting image generation with prompt: {TruncateForLog(request.Prompt)}..."
            );

            // Map our request onto the StableDiffusionNet request type.
            var sdRequest = new TextToImageRequest
            {
                Prompt = request.Prompt,
                NegativePrompt = request.NegativePrompt,
                Steps = request.Steps,
                CfgScale = request.CfgScale,
                Width = request.Width,
                Height = request.Height,
                SamplerName = request.Sampler,
                Scheduler = request.Scheduler,
                Seed = request.Seed,
                // SaveImages/SendImages are not set: the library always returns the images.
            };

            // The client applies the retry policy configured in the constructor.
            var response = await _client.TextToImage.GenerateAsync(
                sdRequest,
                cancellationToken
            );

            if (response?.Images == null || response.Images.Count == 0)
            {
                return GenerationResult.Failure("No images returned from API");
            }

            // Only the first image is used.
            byte[] imageData = Convert.FromBase64String(response.Images[0]);

            string fullPath = BuildUniqueImagePath();
            await File.WriteAllBytesAsync(fullPath, imageData);

            Log.Message($"[AI Images] Image generated successfully and saved to: {fullPath}");
            return GenerationResult.SuccessResult(imageData, fullPath, request);
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Caller-requested cancellation. This must precede the TaskCanceledException
            // handler: HttpClient surfaces both cancellation and timeouts as
            // TaskCanceledException, so without this filter a cancel looks like a timeout.
            Log.Message("[AI Images] Image generation was cancelled");
            return GenerationResult.Failure("Generation cancelled");
        }
        catch (TaskCanceledException)
        {
            // Our token was not cancelled, so this is a timeout.
            Log.Warning("[AI Images] Request timeout. Generation took too long.");
            return GenerationResult.Failure("Request timeout. Generation took too long.");
        }
        catch (OperationCanceledException)
        {
            Log.Message("[AI Images] Image generation was cancelled");
            return GenerationResult.Failure("Generation cancelled");
        }
        catch (HttpRequestException ex)
        {
            Log.Error($"[AI Images] HTTP error: {ex.Message}");
            return GenerationResult.Failure($"Connection error: {ex.Message}");
        }
        catch (StableDiffusionException ex)
        {
            Log.Error($"[AI Images] StableDiffusion API error: {ex.Message}");
            return GenerationResult.Failure($"API Error: {ex.Message}");
        }
        catch (Exception ex)
        {
            // ToString() includes the exception type, inner exceptions and the stack trace.
            Log.Error($"[AI Images] Unexpected error: {ex}");
            return GenerationResult.Failure($"Error: {ex.Message}");
        }
    }

    /// <summary>
    /// Pings the API using the client configured in the constructor.
    /// </summary>
    /// <param name="apiEndpoint">
    /// Unused: kept for the interface; the endpoint given to the constructor is used.
    /// </param>
    /// <param name="cancellationToken">Token that cancels the ping.</param>
    /// <returns>The ping result, or <c>false</c> if the ping throws.</returns>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public async Task<bool> CheckApiAvailability(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            return await _client.PingAsync(cancellationToken);
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] API check failed: {ex.Message}");
            return false;
        }
    }

    /// <summary>
    /// Lists the checkpoints known to the API, preferring each model's Title and
    /// falling back to its ModelName. Models with neither are skipped.
    /// </summary>
    /// <param name="apiEndpoint">
    /// Unused: kept for the interface; the endpoint given to the constructor is used.
    /// </param>
    /// <param name="cancellationToken">Token that cancels the request.</param>
    /// <returns>Model names, or an empty list if the request fails.</returns>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableModels(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            var models = await _client.Models.GetModelsAsync(cancellationToken);
            var modelNames = new List<string>();

            if (models != null)
            {
                foreach (var model in models)
                {
                    string name = model.Title ?? model.ModelName;

                    // Don't hand null/empty entries to UI code that lists these names.
                    if (!string.IsNullOrEmpty(name))
                    {
                        modelNames.Add(name);
                    }
                }
            }

            Log.Message($"[AI Images] Found {modelNames.Count} models");
            return modelNames;
        }
        catch (Exception ex)
        {
            Log.Error($"[AI Images] Failed to load models: {ex.Message}");
            return new List<string>();
        }
    }

    /// <summary>
    /// Lists the samplers reported by the API.
    /// </summary>
    /// <param name="apiEndpoint">
    /// Unused: kept for the interface; the endpoint given to the constructor is used.
    /// </param>
    /// <param name="cancellationToken">Token that cancels the request.</param>
    /// <returns>Sampler names, or a built-in fallback list if the request fails.</returns>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableSamplers(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            var samplers = await _client.Samplers.GetSamplersAsync(cancellationToken);
            var samplerNames = new List<string>();

            if (samplers != null)
            {
                samplerNames.AddRange(samplers);
            }

            Log.Message($"[AI Images] Found {samplerNames.Count} samplers");
            return samplerNames;
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] Failed to load samplers: {ex.Message}");
            return GetDefaultSamplers();
        }
    }

    /// <summary>
    /// Lists the schedulers reported by the API (Schedulers service, available
    /// since StableDiffusionNet 1.1.1 per the original note).
    /// </summary>
    /// <param name="apiEndpoint">
    /// Unused: kept for the interface; the endpoint given to the constructor is used.
    /// </param>
    /// <param name="cancellationToken">Token that cancels the request.</param>
    /// <returns>Scheduler names, or a built-in fallback list if the request fails.</returns>
    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableSchedulers(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            var schedulers = await _client.Schedulers.GetSchedulersAsync(cancellationToken);
            var schedulerNames = new List<string>();

            if (schedulers != null)
            {
                schedulerNames.AddRange(schedulers);
            }

            Log.Message($"[AI Images] Found {schedulerNames.Count} schedulers");
            return schedulerNames;
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] Failed to load schedulers: {ex.Message}");
            return GetDefaultSchedulers();
        }
    }

    /// <summary>Fallback sampler names used when the API cannot be queried.</summary>
    private static List<string> GetDefaultSamplers()
    {
        return new List<string>
        {
            "Euler a",
            "Euler",
            "LMS",
            "Heun",
            "DPM2",
            "DPM2 a",
            "DPM++ 2S a",
            "DPM++ 2M",
            "DPM++ SDE",
            "DPM fast",
            "DPM adaptive",
            "LMS Karras",
            "DPM2 Karras",
            "DPM2 a Karras",
            "DPM++ 2S a Karras",
            "DPM++ 2M Karras",
            "DPM++ SDE Karras",
            "DDIM",
            "PLMS",
        };
    }

    /// <summary>Fallback scheduler names used when the API cannot be queried.</summary>
    private static List<string> GetDefaultSchedulers()
    {
        return new List<string>
        {
            "Automatic",
            "Uniform",
            "Karras",
            "Exponential",
            "Polyexponential",
            "SGM Uniform",
        };
    }

    /// <summary>
    /// Returns at most <see cref="PromptLogLength"/> leading characters of
    /// <paramref name="text"/>, or an empty string for null/empty input.
    /// </summary>
    private static string TruncateForLog(string text)
    {
        if (string.IsNullOrEmpty(text))
        {
            return string.Empty;
        }

        return text.Length <= PromptLogLength ? text : text.Substring(0, PromptLogLength);
    }

    /// <summary>
    /// Builds a "pawn_yyyyMMdd_HHmmss.png" path (local time) in the save folder.
    /// If that file already exists, "_1", "_2", ... is appended so that images
    /// finished within the same second do not overwrite each other.
    /// </summary>
    private string BuildUniqueImagePath()
    {
        string baseName = $"pawn_{DateTime.Now:yyyyMMdd_HHmmss}";
        string fullPath = Path.Combine(_saveFolderPath, baseName + ".png");

        for (int suffix = 1; File.Exists(fullPath); suffix++)
        {
            fullPath = Path.Combine(_saveFolderPath, $"{baseName}_{suffix}.png");
        }

        return fullPath;
    }

    /// <exception cref="ObjectDisposedException">The adapter has been disposed.</exception>
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(StableDiffusionNetAdapter));
        }
    }

    /// <summary>
    /// Disposes the underlying client if it is disposable. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        // Mark first so a throwing client Dispose is not attempted again on a later call.
        _disposed = true;

        if (_client is IDisposable disposableClient)
        {
            disposableClient.Dispose();
        }
    }
}
}