Enhance AIImages mod by adding cancellation support for image generation, improving user experience with localized strings for cancellation actions in English and Russian. Refactor service integration for better dependency management and update AIImages.dll to reflect these changes.

This commit is contained in:
Leonid Pershin
2025-10-26 19:10:45 +03:00
parent 3434927342
commit 02b0143186
11 changed files with 974 additions and 174 deletions

View File

@@ -0,0 +1,377 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using AIImages.Models;
using Newtonsoft.Json;
using Verse;
namespace AIImages.Services
{
/// <summary>
/// Адаптер для Stable Diffusion API (AUTOMATIC1111 WebUI)
/// TODO: В будущем можно мигрировать на библиотеку StableDiffusionNet когда API будет полностью совместимо
/// </summary>
public class StableDiffusionNetAdapter : IStableDiffusionApiService, IDisposable
{
// Shared HttpClient для предотвращения socket exhaustion
// См: https://learn.microsoft.com/en-us/dotnet/fundamentals/networking/http/httpclient-guidelines
private static readonly HttpClient _sharedHttpClient = new HttpClient
{
Timeout = TimeSpan.FromMinutes(5),
};
private readonly string _apiEndpoint;
private readonly string _saveFolderPath;
private bool _disposed;
public StableDiffusionNetAdapter(string apiEndpoint, string savePath = "AIImages/Generated")
{
if (string.IsNullOrEmpty(apiEndpoint))
{
throw new ArgumentException(
"API endpoint cannot be null or empty",
nameof(apiEndpoint)
);
}
_apiEndpoint = apiEndpoint;
// Определяем путь для сохранения
_saveFolderPath = Path.Combine(GenFilePaths.SaveDataFolderPath, savePath);
// Создаем папку, если не существует
if (!Directory.Exists(_saveFolderPath))
{
Directory.CreateDirectory(_saveFolderPath);
}
Log.Message(
$"[AI Images] StableDiffusion adapter initialized with endpoint: {apiEndpoint}"
);
}
public async Task<GenerationResult> GenerateImageAsync(
GenerationRequest request,
CancellationToken cancellationToken = default
)
{
ThrowIfDisposed();
if (request == null)
{
return GenerationResult.Failure("Request cannot be null");
}
try
{
Log.Message(
$"[AI Images] Starting image generation with prompt: {request.Prompt.Substring(0, Math.Min(50, request.Prompt.Length))}..."
);
// Формируем JSON запрос для AUTOMATIC1111 API
var apiRequest = new
{
prompt = request.Prompt,
negative_prompt = request.NegativePrompt,
steps = request.Steps,
cfg_scale = request.CfgScale,
width = request.Width,
height = request.Height,
sampler_name = request.Sampler,
scheduler = request.Scheduler,
seed = request.Seed,
save_images = false,
send_images = true,
};
string jsonRequest = JsonConvert.SerializeObject(apiRequest);
var content = new StringContent(jsonRequest, Encoding.UTF8, "application/json");
// Отправляем запрос с поддержкой cancellation
string endpoint = $"{_apiEndpoint}/sdapi/v1/txt2img";
HttpResponseMessage response = await _sharedHttpClient.PostAsync(
endpoint,
content,
cancellationToken
);
if (!response.IsSuccessStatusCode)
{
string errorContent = await response.Content.ReadAsStringAsync();
Log.Error(
$"[AI Images] API request failed: {response.StatusCode} - {errorContent}"
);
return GenerationResult.Failure($"API Error: {response.StatusCode}");
}
string jsonResponse = await response.Content.ReadAsStringAsync();
var apiResponse = JsonConvert.DeserializeObject<Txt2ImgResponse>(jsonResponse);
if (apiResponse?.images == null || apiResponse.images.Length == 0)
{
return GenerationResult.Failure("No images returned from API");
}
// Декодируем изображение из base64
byte[] imageData = Convert.FromBase64String(apiResponse.images[0]);
// Сохраняем изображение
string fileName = $"pawn_{DateTime.Now:yyyyMMdd_HHmmss}.png";
string fullPath = Path.Combine(_saveFolderPath, fileName);
await File.WriteAllBytesAsync(fullPath, imageData);
Log.Message($"[AI Images] Image generated successfully and saved to: {fullPath}");
return GenerationResult.SuccessResult(imageData, fullPath, request);
}
catch (TaskCanceledException)
{
Log.Warning("[AI Images] Request timeout. Generation took too long.");
return GenerationResult.Failure("Request timeout. Generation took too long.");
}
catch (OperationCanceledException)
{
Log.Message("[AI Images] Image generation was cancelled");
return GenerationResult.Failure("Generation cancelled");
}
catch (HttpRequestException ex)
{
Log.Error($"[AI Images] HTTP error: {ex.Message}");
return GenerationResult.Failure($"Connection error: {ex.Message}");
}
catch (Exception ex)
{
Log.Error($"[AI Images] Unexpected error: {ex.Message}\n{ex.StackTrace}");
return GenerationResult.Failure($"Error: {ex.Message}");
}
}
public async Task<bool> CheckApiAvailability(
string apiEndpoint,
CancellationToken cancellationToken = default
)
{
ThrowIfDisposed();
try
{
string endpoint = $"{apiEndpoint}/sdapi/v1/sd-models";
HttpResponseMessage response = await _sharedHttpClient.GetAsync(
endpoint,
cancellationToken
);
return response.IsSuccessStatusCode;
}
catch (Exception ex)
{
Log.Warning($"[AI Images] API check failed: {ex.Message}");
return false;
}
}
public async Task<List<string>> GetAvailableModels(
string apiEndpoint,
CancellationToken cancellationToken = default
)
{
ThrowIfDisposed();
try
{
string endpoint = $"{apiEndpoint}/sdapi/v1/sd-models";
HttpResponseMessage response = await _sharedHttpClient.GetAsync(
endpoint,
cancellationToken
);
if (!response.IsSuccessStatusCode)
return new List<string>();
string jsonResponse = await response.Content.ReadAsStringAsync();
var models = JsonConvert.DeserializeObject<List<SdModel>>(jsonResponse);
var modelNames = new List<string>();
if (models != null)
{
foreach (var model in models)
{
modelNames.Add(model.title ?? model.model_name);
}
}
Log.Message($"[AI Images] Found {modelNames.Count} models");
return modelNames;
}
catch (Exception ex)
{
Log.Error($"[AI Images] Failed to load models: {ex.Message}");
return new List<string>();
}
}
public async Task<List<string>> GetAvailableSamplers(
string apiEndpoint,
CancellationToken cancellationToken = default
)
{
ThrowIfDisposed();
try
{
string endpoint = $"{apiEndpoint}/sdapi/v1/samplers";
HttpResponseMessage response = await _sharedHttpClient.GetAsync(
endpoint,
cancellationToken
);
if (!response.IsSuccessStatusCode)
return GetDefaultSamplers();
string jsonResponse = await response.Content.ReadAsStringAsync();
var samplers = JsonConvert.DeserializeObject<List<SdSampler>>(jsonResponse);
var samplerNames = new List<string>();
if (samplers != null)
{
foreach (var sampler in samplers)
{
samplerNames.Add(sampler.name);
}
}
Log.Message($"[AI Images] Found {samplerNames.Count} samplers");
return samplerNames;
}
catch (Exception ex)
{
Log.Warning($"[AI Images] Failed to load samplers: {ex.Message}");
return GetDefaultSamplers();
}
}
public async Task<List<string>> GetAvailableSchedulers(
string apiEndpoint,
CancellationToken cancellationToken = default
)
{
ThrowIfDisposed();
try
{
string endpoint = $"{apiEndpoint}/sdapi/v1/schedulers";
HttpResponseMessage response = await _sharedHttpClient.GetAsync(
endpoint,
cancellationToken
);
if (!response.IsSuccessStatusCode)
return GetDefaultSchedulers();
string jsonResponse = await response.Content.ReadAsStringAsync();
var schedulers = JsonConvert.DeserializeObject<List<SdScheduler>>(jsonResponse);
var schedulerNames = new List<string>();
if (schedulers != null)
{
foreach (var scheduler in schedulers)
{
schedulerNames.Add(scheduler.name);
}
}
Log.Message($"[AI Images] Found {schedulerNames.Count} schedulers");
return schedulerNames;
}
catch (Exception ex)
{
Log.Warning($"[AI Images] Failed to load schedulers: {ex.Message}");
return GetDefaultSchedulers();
}
}
private List<string> GetDefaultSamplers()
{
return new List<string>
{
"Euler a",
"Euler",
"LMS",
"Heun",
"DPM2",
"DPM2 a",
"DPM++ 2S a",
"DPM++ 2M",
"DPM++ SDE",
"DPM fast",
"DPM adaptive",
"LMS Karras",
"DPM2 Karras",
"DPM2 a Karras",
"DPM++ 2S a Karras",
"DPM++ 2M Karras",
"DPM++ SDE Karras",
"DDIM",
"PLMS",
};
}
private List<string> GetDefaultSchedulers()
{
return new List<string>
{
"Automatic",
"Uniform",
"Karras",
"Exponential",
"Polyexponential",
"SGM Uniform",
};
}
private void ThrowIfDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(nameof(StableDiffusionNetAdapter));
}
}
public void Dispose()
{
if (_disposed)
return;
// Не dispose shared HttpClient - он используется глобально
_disposed = true;
}
// Вспомогательные классы для десериализации JSON ответов
#pragma warning disable S3459, S1144, IDE1006 // Properties set by JSON deserializer
private sealed class Txt2ImgResponse
{
public string[] images { get; set; }
}
private sealed class SdModel
{
public string title { get; set; }
public string model_name { get; set; }
}
private sealed class SdSampler
{
public string name { get; set; }
}
private sealed class SdScheduler
{
public string name { get; set; }
}
#pragma warning restore S3459, S1144, IDE1006
}
}